From d0825514a05754f864f3d578a06d2bceb402f54b Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 17 Apr 2026 10:25:05 -0500 Subject: [PATCH] GSO again --- batch.go | 70 ++++ batch_test.go | 137 +++++++ connection_manager_test.go | 9 +- inside.go | 69 +++- interface.go | 103 +++++- outside.go | 2 +- overlay/device.go | 55 ++- test/tun.go => overlay/noop.go | 13 +- overlay/tun_android.go | 39 +- overlay/tun_darwin.go | 43 ++- overlay/tun_disabled.go | 41 ++- overlay/tun_freebsd.go | 24 +- overlay/tun_ios.go | 37 +- overlay/tun_linux.go | 434 +++++++++++++++++++--- overlay/tun_linux_offload.go | 331 +++++++++++++++++ overlay/tun_linux_offload_test.go | 333 +++++++++++++++++ overlay/tun_netbsd.go | 24 +- overlay/tun_openbsd.go | 24 +- overlay/tun_tester.go | 26 +- overlay/tun_windows.go | 26 +- overlay/user.go | 23 +- tcp_coalesce.go | 484 +++++++++++++++++++++++++ tcp_coalesce_test.go | 576 ++++++++++++++++++++++++++++++ udp/conn.go | 40 ++- udp/udp_darwin.go | 23 +- udp/udp_generic.go | 23 +- udp/udp_linux.go | 323 ++++++++++++++++- udp/udp_linux_32.go | 32 ++ udp/udp_linux_64.go | 32 ++ udp/udp_rio_windows.go | 23 +- udp/udp_tester.go | 23 +- 31 files changed, 3278 insertions(+), 164 deletions(-) create mode 100644 batch.go create mode 100644 batch_test.go rename test/tun.go => overlay/noop.go (76%) create mode 100644 overlay/tun_linux_offload.go create mode 100644 overlay/tun_linux_offload_test.go create mode 100644 tcp_coalesce.go create mode 100644 tcp_coalesce_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 00000000..02d86bc7 --- /dev/null +++ b/batch.go @@ -0,0 +1,70 @@ +package nebula + +import "net/netip" + +// sendBatchCap is the maximum number of encrypted packets accumulated before a +// flush is forced. TSO superpackets segment to at most ~45 packets on +// reasonable MTUs, so 128 leaves headroom without bloating the backing +// allocation. +const sendBatchCap = 128 + +// sendBatch accumulates encrypted UDP packets for a single sendmmsg flush. +// One sendBatch is owned by each listenIn goroutine; no locking is needed. +// The backing storage holds up to batchCap packets of slotCap bytes each; +// bufs and dsts are parallel slices of committed slots. +type sendBatch struct { + bufs [][]byte + dsts []netip.AddrPort + backing []byte + slotCap int + batchCap int + nextSlot int +} + +func newSendBatch(batchCap, slotCap int) *sendBatch { + return &sendBatch{ + bufs: make([][]byte, 0, batchCap), + dsts: make([]netip.AddrPort, 0, batchCap), + backing: make([]byte, batchCap*slotCap), + slotCap: slotCap, + batchCap: batchCap, + } +} + +// Next returns a zero-length slice with slotCap capacity over the next unused +// slot's backing bytes. The caller writes into the returned slice and then +// calls Commit with the final length and destination. Next returns nil when +// the batch is full. +func (b *sendBatch) Next() []byte { + if b.nextSlot >= b.batchCap { + return nil + } + start := b.nextSlot * b.slotCap + return b.backing[start : start : start+b.slotCap] +} + +// Commit records the slot just returned by Next as a packet of length n +// destined for dst. +func (b *sendBatch) Commit(n int, dst netip.AddrPort) { + start := b.nextSlot * b.slotCap + b.bufs = append(b.bufs, b.backing[start:start+n]) + b.dsts = append(b.dsts, dst) + b.nextSlot++ +} + +// Reset clears committed slots; backing storage is retained for reuse. +func (b *sendBatch) Reset() { + b.bufs = b.bufs[:0] + b.dsts = b.dsts[:0] + b.nextSlot = 0 +} + +// Len returns the number of committed packets. +func (b *sendBatch) Len() int { + return len(b.bufs) +} + +// Cap returns the maximum number of slots in the batch. +func (b *sendBatch) Cap() int { + return b.batchCap +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 00000000..f33f9381 --- /dev/null +++ b/batch_test.go @@ -0,0 +1,137 @@ +package nebula + +import ( + "net/netip" + "testing" +) + +func TestSendBatchBookkeeping(t *testing.T) { + b := newSendBatch(4, 32) + if b.Len() != 0 || b.Cap() != 4 { + t.Fatalf("fresh batch: len=%d cap=%d", b.Len(), b.Cap()) + } + + ap := netip.MustParseAddrPort("10.0.0.1:4242") + for i := 0; i < 4; i++ { + slot := b.Next() + if slot == nil { + t.Fatalf("slot %d: Next returned nil before cap", i) + } + if cap(slot) != 32 || len(slot) != 0 { + t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot)) + } + // Write a marker byte. + slot = append(slot, byte(i), byte(i+1), byte(i+2)) + b.Commit(len(slot), ap) + } + if b.Next() != nil { + t.Fatalf("Next should return nil when full") + } + if b.Len() != 4 { + t.Fatalf("Len=%d want 4", b.Len()) + } + for i, buf := range b.bufs { + if len(buf) != 3 || buf[0] != byte(i) { + t.Errorf("buf %d: %x", i, buf) + } + if b.dsts[i] != ap { + t.Errorf("dst %d: got %v want %v", i, b.dsts[i], ap) + } + } + + // Reset returns empty and Next works again. + b.Reset() + if b.Len() != 0 { + t.Fatalf("after Reset Len=%d want 0", b.Len()) + } + slot := b.Next() + if slot == nil || cap(slot) != 32 { + t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", slot == nil, cap(slot)) + } +} + +func TestBatchSegmentable(t *testing.T) { + ap := netip.MustParseAddrPort("10.0.0.1:4242") + other := netip.MustParseAddrPort("10.0.0.2:4242") + + mk := func(addrs []netip.AddrPort, sizes []int) *sendBatch { + b := newSendBatch(len(addrs), 64) + for i, a := range addrs { + s := b.Next() + for j := 0; j < sizes[i]; j++ { + s = append(s, byte(j)) + } + b.Commit(len(s), a) + } + return b + } + + t.Run("uniform same dst", func(t *testing.T) { + b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 10, 10}) + seg, ok := batchSegmentable(b) + if !ok || seg != 10 { + t.Fatalf("got seg=%d ok=%v", seg, ok) + } + }) + + t.Run("last segment short ok", func(t *testing.T) { + b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 10, 4}) + seg, ok := batchSegmentable(b) + if !ok || seg != 10 { + t.Fatalf("got seg=%d ok=%v", seg, ok) + } + }) + + t.Run("mixed dst rejected", func(t *testing.T) { + b := mk([]netip.AddrPort{ap, other, ap}, []int{10, 10, 10}) + if _, ok := batchSegmentable(b); ok { + t.Fatalf("expected rejection for mixed dst") + } + }) + + t.Run("mid-batch short rejected", func(t *testing.T) { + b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 4, 10}) + if _, ok := batchSegmentable(b); ok { + t.Fatalf("expected rejection for short mid-batch") + } + }) + + t.Run("mid-batch longer rejected", func(t *testing.T) { + b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 11, 10}) + if _, ok := batchSegmentable(b); ok { + t.Fatalf("expected rejection for longer mid-batch") + } + }) + + t.Run("last longer rejected", func(t *testing.T) { + b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 10, 11}) + if _, ok := batchSegmentable(b); ok { + t.Fatalf("expected rejection for longer last segment") + } + }) + + t.Run("first zero rejected", func(t *testing.T) { + b := mk([]netip.AddrPort{ap, ap}, []int{0, 10}) + if _, ok := batchSegmentable(b); ok { + t.Fatalf("expected rejection for zero first") + } + }) +} + +func TestSendBatchSlotsDoNotOverlap(t *testing.T) { + b := newSendBatch(3, 8) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + // Fill three slots, each with its own sentinel byte. + for i := 0; i < 3; i++ { + s := b.Next() + s = append(s, byte(0xA0+i), byte(0xB0+i)) + b.Commit(len(s), ap) + } + + for i, buf := range b.bufs { + if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) { + t.Errorf("slot %d corrupted: %x", i, buf) + } + } +} diff --git a/connection_manager_test.go b/connection_manager_test.go index 647dd72b..6d07cd36 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlay.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlay.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlay.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlay.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, diff --git a/inside.go b/inside.go index 0d53f952..981e93a0 100644 --- a/inside.go +++ b/inside.go @@ -11,7 +11,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, batch *sendBatch, rejectBuf []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { @@ -33,7 +33,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { - _, err := f.readers[q].Write(packet) + _, err := f.readers[q].WriteReject(packet) if err != nil { f.l.WithError(err).Error("Failed to forward to tun") } @@ -53,7 +53,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet }) if hostinfo == nil { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). @@ -68,10 +68,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) + f.sendInsideMessage(hostinfo, packet, nb, batch, rejectBuf, q) } else { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l). WithField("fwPacket", fwPacket). @@ -81,6 +81,63 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } +// sendInsideMessage encrypts a firewall-approved inside packet into the +// caller's batch slot for later sendmmsg flush. When hostinfo.remote is not +// valid we fall through to the relay slow path via the unbatched sendNoMetrics +// so relay behavior is unchanged. +func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, batch *sendBatch, rejectBuf []byte, q int) { + ci := hostinfo.ConnectionState + if ci.eKey == nil { + return + } + + if !hostinfo.remote.IsValid() { + // Slow path: relay fallback. Reuse rejectBuf as the ciphertext + // scratch; sendNoMetrics arranges header space for SendVia. + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q) + return + } + + scratch := batch.Next() + if scratch == nil { + // Batch full: bypass batching and send this packet directly so we + // never drop traffic on over-subscribed iterations. + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q) + return + } + + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + if hostinfo.lastRebindCount != f.rebindCount { + //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is + // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + out, err := ci.eKey.EncryptDanger(out, out, p, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("udpAddr", hostinfo.remote).WithField("counter", c). + Error("Failed to encrypt outgoing packet") + return + } + + batch.Commit(len(out), hostinfo.remote) +} + func (f *Interface) rejectInside(packet []byte, out []byte, q int) { if !f.firewall.InSendReject { return @@ -91,7 +148,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { return } - _, err := f.readers[q].Write(out) + _, err := f.readers[q].WriteReject(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") } diff --git a/interface.go b/interface.go index 9e7a98a9..590b81ea 100644 --- a/interface.go +++ b/interface.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/netip" "sync" "sync/atomic" @@ -86,8 +85,12 @@ type Interface struct { conntrackCacheTimeout time.Duration writers []udp.Conn - readers []io.ReadWriteCloser - wg sync.WaitGroup + readers []overlay.Queue + // tunCoalescers is one tcpCoalescer per tun queue, wrapping readers[i]. + // decryptToTun sends plaintext into the coalescer; listenOut calls its + // Flush at the end of each UDP recvmmsg batch. + tunCoalescers []*tcpCoalescer + wg sync.WaitGroup // fatalErr holds the first unexpected reader error that caused shutdown. // nil means "no fatal error" (yet) @@ -184,7 +187,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), + readers: make([]overlay.Queue, c.routines), + tunCoalescers: make([]*tcpCoalescer, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, @@ -239,7 +243,7 @@ func (f *Interface) activate() error { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues - var reader io.ReadWriteCloser = f.inside + var reader overlay.Queue = f.inside for i := 0; i < f.routines; i++ { if i > 0 { reader, err = f.inside.NewMultiQueueReader() @@ -248,6 +252,7 @@ func (f *Interface) activate() error { } } f.readers[i] = reader + f.tunCoalescers[i] = newTCPCoalescer(reader) } f.wg.Add(1) // for us to wait on Close() to return @@ -305,13 +310,28 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) h := &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) + // plaintexts is a ring of decrypt scratches, one per packet in a UDP + // recvmmsg batch. The coalescer borrows payload slices from here and + // requires they stay valid until Flush — so we rotate each packet and + // reset only in the batch-end flush callback. + var plaintexts [][]byte + idx := 0 + coalescer := f.tunCoalescers[i] err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + if idx >= len(plaintexts) { + plaintexts = append(plaintexts, make([]byte, udp.MTU)) + } + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintexts[idx][:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + idx++ + }, func() { + if err := coalescer.Flush(); err != nil { + f.l.WithError(err).Error("Failed to flush tun coalescer") + } + idx = 0 }) if err != nil && !f.closed.Load() { @@ -322,16 +342,16 @@ func (f *Interface) listenOut(i int) { f.l.Debugf("underlay reader %v is done", i) } -func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - packet := make([]byte, mtu) - out := make([]byte, mtu) +func (f *Interface) listenIn(reader overlay.Queue, i int) { + rejectBuf := make([]byte, mtu) + batch := newSendBatch(sendBatchCap, udp.MTU+32) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { - n, err := reader.Read(packet) + pkts, err := reader.Read() if err != nil { if !f.closed.Load() { f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") @@ -340,12 +360,71 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + batch.Reset() + for _, pkt := range pkts { + if batch.Len() >= batch.Cap() { + f.flushBatch(batch, i) + batch.Reset() + } + f.consumeInsidePacket(pkt, fwPacket, nb, batch, rejectBuf, i, conntrackCache.Get(f.l)) + } + if batch.Len() > 0 { + f.flushBatch(batch, i) + } } f.l.Debugf("overlay reader %v is done", i) } +func (f *Interface) flushBatch(batch *sendBatch, q int) { + //if len(batch.bufs) == 1 { + // if err := f.writers[q].WriteTo(batch.bufs[0], batch.dsts[0]); err != nil { + // f.l.WithError(err).WithField("writer", q).Error("Failed to write outgoing single-batch") + // } + // return + //} + w := f.writers[q] + if w.SupportsGSO() { + if segSize, ok := batchSegmentable(batch); ok { + if err := w.WriteSegmented(batch.bufs, batch.dsts[0], segSize); err != nil { + f.l.WithError(err).WithField("writer", q).Error("Failed to write outgoing GSO batch") + } + return + } + } + if err := w.WriteBatch(batch.bufs, batch.dsts); err != nil { + f.l.WithError(err).WithField("writer", q).Error("Failed to write outgoing batch") + } +} + +// batchSegmentable reports whether a batch can be emitted as a single UDP GSO +// superpacket: all packets go to the same destination, and every packet +// except possibly the last has the same length. Returns the segment size on +// success. The single-packet case is handled in flushBatch before this runs. +func batchSegmentable(b *sendBatch) (int, bool) { + segSize := len(b.bufs[0]) + if segSize == 0 { + return 0, false + } + dst := b.dsts[0] + last := len(b.bufs) - 1 + for i := 1; i <= last; i++ { + if b.dsts[i] != dst { + return 0, false + } + if i < last { + if len(b.bufs[i]) != segSize { + return 0, false + } + } else { + if len(b.bufs[i]) == 0 || len(b.bufs[i]) > segSize { + return 0, false + } + } + } + return segSize, true +} + func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) diff --git a/outside.go b/outside.go index eba9d887..41fa5dd4 100644 --- a/outside.go +++ b/outside.go @@ -535,7 +535,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out } f.connectionManager.In(hostinfo) - _, err = f.readers[q].Write(out) + err = f.tunCoalescers[q].Add(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") } diff --git a/overlay/device.go b/overlay/device.go index b6077aba..70ca01a5 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -7,12 +7,63 @@ import ( "github.com/slackhq/nebula/routing" ) +// defaultBatchBufSize is the per-Queue scratch size for Read on backends +// that don't do TSO segmentation. 65535 covers any single IP packet. +const defaultBatchBufSize = 65535 + +// Queue is a readable/writable tun queue. One Queue is driven by a single +// read goroutine plus concurrent writers (see Write / WriteReject below). +type Queue interface { + io.Closer + + // Read returns one or more packets. The returned slices are borrowed + // from the Queue's internal buffer and are only valid until the next + // Read or Close on this Queue — callers must encrypt or copy each + // slice before the next call. Not safe for concurrent Reads; exactly + // one goroutine per Queue reads. + Read() ([][]byte, error) + + // Write emits a single packet on the plaintext (outside→inside) + // delivery path. May run concurrently with WriteReject on the same + // Queue, but not with itself. + Write(p []byte) (int, error) + + // WriteReject writes a single packet that originated from the inside + // path (reject replies or self-forward) using scratch state distinct + // from Write, so it can run concurrently with Write on the same Queue + // without a data race. On backends without a shared-scratch Write, a + // trivial delegation to Write is acceptable. + WriteReject(p []byte) (int, error) +} + type Device interface { - io.ReadWriteCloser + Queue Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool - NewMultiQueueReader() (io.ReadWriteCloser, error) + NewMultiQueueReader() (Queue, error) +} + +// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket +// assembled from a header prefix plus one or more borrowed payload +// fragments, in a single vectored write (writev with a leading +// virtio_net_hdr). This lets the coalescer avoid copying payload bytes +// between the caller's decrypt buffer and the TUN. Backends without GSO +// support return false from GSOSupported and coalescing is skipped. +// +// hdr contains the IPv4/IPv6 + TCP header prefix (mutable — callers will +// have filled in total length and pseudo-header partial). pays are +// non-overlapping payload fragments whose concatenation is the full +// superpacket payload; they are read-only from the writer's perspective +// and must remain valid until the call returns. gsoSize is the MSS: +// every segment except possibly the last is exactly that many bytes. +// csumStart is the byte offset where the TCP header begins within hdr. +// +// hdr's TCP checksum field must already hold the pseudo-header partial +// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. +type GSOWriter interface { + WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error + GSOSupported() bool } diff --git a/test/tun.go b/overlay/noop.go similarity index 76% rename from test/tun.go rename to overlay/noop.go index fb32782f..dc2d3fb9 100644 --- a/test/tun.go +++ b/overlay/noop.go @@ -1,8 +1,7 @@ -package test +package overlay import ( "errors" - "io" "net/netip" "github.com/slackhq/nebula/routing" @@ -26,19 +25,23 @@ func (NoopTun) Name() string { return "noop" } -func (NoopTun) Read([]byte) (int, error) { - return 0, nil +func (NoopTun) Read() ([][]byte, error) { + return nil, nil } func (NoopTun) Write([]byte) (int, error) { return 0, nil } +func (NoopTun) WriteReject(p []byte) (int, error) { + return 0, nil +} + func (NoopTun) SupportsMultiqueue() bool { return false } -func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (NoopTun) NewMultiQueueReader() (Queue, error) { return nil, errors.New("unsupported") } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..62de337d 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -18,12 +18,39 @@ import ( ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser fd int vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.rwc.Read(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) Write(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) WriteReject(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) Close() error { + return t.rwc.Close() } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { @@ -32,10 +59,10 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t := &tun{ - ReadWriteCloser: file, - fd: deviceFd, - vpnNetworks: vpnNetworks, - l: l, + rwc: file, + fd: deviceFd, + vpnNetworks: vpnNetworks, + l: l, } err := t.reload(c, true) @@ -99,6 +126,6 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..7f50c705 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -23,7 +23,7 @@ import ( ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser Device string vpnNetworks []netip.Prefix DefaultMTU int @@ -34,6 +34,9 @@ type tun struct { // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte + + readBuf []byte + batchRet [1][]byte } type ifReq struct { @@ -124,11 +127,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } t := &tun{ - ReadWriteCloser: os.NewFile(uintptr(fd), ""), - Device: name, - vpnNetworks: vpnNetworks, - DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + rwc: os.NewFile(uintptr(fd), ""), + Device: name, + vpnNetworks: vpnNetworks, + DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -158,8 +161,8 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, } func (t *tun) Close() error { - if t.ReadWriteCloser != nil { - return t.ReadWriteCloser.Close() + if t.rwc != nil { + return t.rwc.Close() } return nil } @@ -503,15 +506,31 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { buf := make([]byte, len(to)+4) - n, err := t.ReadWriteCloser.Read(buf) + n, err := t.rwc.Read(buf) copy(to, buf[4:]) return n - 4, err } +func (t *tun) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteReject(p []byte) (int, error) { + return t.Write(p) +} + // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { buf := t.out @@ -537,7 +556,7 @@ func (t *tun) Write(from []byte) (int, error) { copy(buf[4:], from) - n, err := t.ReadWriteCloser.Write(buf) + n, err := t.rwc.Write(buf) return n - 4, err } @@ -553,6 +572,6 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..8a691ae0 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -20,6 +20,23 @@ type disabledTun struct { tx metrics.Counter rx metrics.Counter l *logrus.Logger + + batchRet [1][]byte +} + +func (t *disabledTun) Read() ([][]byte, error) { + r, ok := <-t.read + if !ok { + return nil, io.EOF + } + + t.tx.Inc(1) + if t.l.Level >= logrus.DebugLevel { + t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") + } + + t.batchRet[0] = r + return t.batchRet[:], nil } func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { @@ -56,24 +73,6 @@ func (*disabledTun) Name() string { return "disabled" } -func (t *disabledTun) Read(b []byte) (int, error) { - r, ok := <-t.read - if !ok { - return 0, io.EOF - } - - if len(r) > len(b) { - return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b)) - } - - t.tx.Inc(1) - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") - } - - return copy(b, r), nil -} - func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { out := make([]byte, len(b)) out = iputil.CreateICMPEchoResponse(b, out) @@ -105,11 +104,15 @@ func (t *disabledTun) Write(b []byte) (int, error) { return len(b), nil } +func (t *disabledTun) WriteReject(b []byte) (int, error) { + return t.Write(b) +} + func (t *disabledTun) SupportsMultiqueue() bool { return true } -func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *disabledTun) NewMultiQueueReader() (Queue, error) { return t, nil } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 91c51159..68278932 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -7,7 +7,6 @@ import ( "bytes" "errors" "fmt" - "io" "io/fs" "net/netip" "os" @@ -101,6 +100,9 @@ type tun struct { readPoll [2]unix.PollFd writePoll [2]unix.PollFd closed atomic.Bool + + readBuf []byte + batchRet [1][]byte } // blockOnRead waits until the tun fd is readable or shutdown has been signaled. @@ -155,7 +157,23 @@ func (t *tun) blockOnWrite() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteReject(p []byte) (int, error) { + return t.Write(p) +} + +func (t *tun) readOne(to []byte) (int, error) { // first 4 bytes is protocol family, in network byte order var head [4]byte iovecs := [2]syscall.Iovec{ @@ -563,7 +581,7 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..ebf134b8 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -21,11 +21,38 @@ import ( ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.rwc.Read(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) Write(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) WriteReject(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) Close() error { + return t.rwc.Close() } func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { @@ -35,9 +62,9 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - vpnNetworks: vpnNetworks, - ReadWriteCloser: &tunReadCloser{f: file}, - l: l, + vpnNetworks: vpnNetworks, + rwc: &tunReadCloser{f: file}, + l: l, } err := t.reload(c, true) @@ -155,6 +182,6 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 2830ff6b..3a75685b 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -10,9 +10,11 @@ import ( "net" "net/netip" "os" + "runtime" "strings" "sync" "sync/atomic" + "syscall" "time" "unsafe" @@ -34,16 +36,58 @@ type tunFile struct { readPoll [2]unix.PollFd writePoll [2]unix.PollFd closed bool + + // vnetHdr is true when this fd was opened with IFF_VNET_HDR and the + // kernel successfully accepted TUNSETOFFLOAD. Reads include a leading + // virtio_net_hdr and may carry a TSO superpacket we must segment; + // writes must prepend a zeroed virtio_net_hdr. + vnetHdr bool + readBuf []byte // scratch for a single raw read (virtio hdr + superpacket) + segBuf []byte // backing store for segmented output + segOff int // cursor into segBuf for the current Read drain + pending [][]byte // segments returned from the most recent Read + writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr + // rejectIovs is a second preallocated iovec scratch used exclusively by + // WriteReject (reject + self-forward from the inside path). It mirrors + // writeIovs but lets listenIn goroutines emit reject packets without + // racing with the listenOut coalescer that owns writeIovs. + rejectIovs [2]unix.Iovec + + // gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted + // by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on + // another queue never observes a half-written header. + gsoHdrBuf [virtioNetHdrLen]byte + // gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the + // virtio header + IP/TCP header + up to gsoInitialPayIovs payload + // fragments; grown on demand if a coalescer pushes more. + gsoIovs []unix.Iovec } +// gsoInitialPayIovs is the starting capacity (in payload fragments) of +// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without +// any reallocations. +const gsoInitialPayIovs = 66 + +// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN +// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb +// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum +// verification. All packets that reach the plain Write / WriteReject paths +// already carry a valid L4 checksum (either supplied by a remote peer whose +// ciphertext we AEAD-authenticated, or produced by finishChecksum during TSO +// segmentation, or built locally by CreateRejectPacket), so trusting them is +// safe. +var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID} + // newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun func (r *tunFile) newFriend(fd int) (*tunFile, error) { if err := unix.SetNonblock(fd, true); err != nil { return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) } - return &tunFile{ + out := &tunFile{ fd: fd, shutdownFd: r.shutdownFd, + vnetHdr: r.vnetHdr, + readBuf: make([]byte, tunReadBufSize), readPoll: [2]unix.PollFd{ {Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, @@ -52,10 +96,21 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) { {Fd: int32(fd), Events: unix.POLLOUT}, {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, }, - }, nil + } + if r.vnetHdr { + out.segBuf = make([]byte, tunSegBufCap) + out.writeIovs[0].Base = &validVnetHdr[0] + out.writeIovs[0].SetLen(virtioNetHdrLen) + out.rejectIovs[0].Base = &validVnetHdr[0] + out.rejectIovs[0].SetLen(virtioNetHdrLen) + out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs) + out.gsoIovs[0].Base = &out.gsoHdrBuf[0] + out.gsoIovs[0].SetLen(virtioNetHdrLen) + } + return out, nil } -func newTunFd(fd int) (*tunFile, error) { +func newTunFd(fd int, vnetHdr bool) (*tunFile, error) { if err := unix.SetNonblock(fd, true); err != nil { return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) } @@ -69,6 +124,8 @@ func newTunFd(fd int) (*tunFile, error) { fd: fd, shutdownFd: shutdownFd, lastOne: true, + vnetHdr: vnetHdr, + readBuf: make([]byte, tunReadBufSize), readPoll: [2]unix.PollFd{ {Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(shutdownFd), Events: unix.POLLIN}, @@ -78,6 +135,16 @@ func newTunFd(fd int) (*tunFile, error) { {Fd: int32(shutdownFd), Events: unix.POLLIN}, }, } + if vnetHdr { + out.segBuf = make([]byte, tunSegBufCap) + out.writeIovs[0].Base = &validVnetHdr[0] + out.writeIovs[0].SetLen(virtioNetHdrLen) + out.rejectIovs[0].Base = &validVnetHdr[0] + out.rejectIovs[0].SetLen(virtioNetHdrLen) + out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs) + out.gsoIovs[0].Base = &out.gsoHdrBuf[0] + out.gsoIovs[0].SetLen(virtioNetHdrLen) + } return out, nil } @@ -134,7 +201,7 @@ func (r *tunFile) blockOnWrite() error { return nil } -func (r *tunFile) Read(buf []byte) (int, error) { +func (r *tunFile) readRaw(buf []byte) (int, error) { for { if n, err := unix.Read(r.fd, buf); err == nil { return n, nil @@ -153,22 +220,238 @@ func (r *tunFile) Read(buf []byte) (int, error) { } } -func (r *tunFile) Write(buf []byte) (int, error) { +// Read reads one or more superpackets from the tun and returns the +// resulting packets. The first read blocks via poll; once the fd is known +// readable we drain additional packets non-blocking until the kernel queue +// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of +// segBuf headroom. This amortizes the poll wake over bursts of small +// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers +// and are only valid until the next Read or Close on this Queue. +func (r *tunFile) Read() ([][]byte, error) { + r.pending = r.pending[:0] + r.segOff = 0 + + // Initial (blocking) read. Retry on decode errors so a single bad + // packet does not stall the reader. for { - if n, err := unix.Write(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { - if err = r.blockOnWrite(); err != nil { + n, err := r.readRaw(r.readBuf) + if err != nil { + return nil, err + } + if !r.vnetHdr { + r.pending = append(r.pending, r.readBuf[:n]) + // Non-vnetHdr mode shares one readBuf so we can't drain safely + // without copying; return the single packet as before. + return r.pending, nil + } + if err := r.decodeRead(n); err != nil { + // Drop and read again — a bad packet should not kill the reader. + continue + } + break + } + + // Drain: non-blocking reads until the kernel queue is empty, the drain + // cap is reached, or segBuf no longer has room for another worst-case + // superpacket. + for len(r.pending) < tunDrainCap && tunSegBufCap-r.segOff >= tunSegBufSize { + n, err := unix.Read(r.fd, r.readBuf) + if err != nil { + // EAGAIN / EINTR / anything else: stop draining. We already + // have a valid batch from the first read. + break + } + if n <= 0 { + break + } + if err := r.decodeRead(n); err != nil { + // Drop this packet and stop the drain; we'd rather hand off + // what we have than keep spinning here. + break + } + } + + return r.pending, nil +} + +// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends +// the segments to r.pending, and advances r.segOff by the total scratch used. +// Caller must have already ensured r.vnetHdr is true. +func (r *tunFile) decodeRead(n int) error { + if n < virtioNetHdrLen { + return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen) + } + var hdr virtioNetHdr + hdr.decode(r.readBuf[:virtioNetHdrLen]) + before := len(r.pending) + if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf[r.segOff:]); err != nil { + return err + } + for k := before; k < len(r.pending); k++ { + r.segOff += len(r.pending[k]) + } + return nil +} + +func (r *tunFile) Write(buf []byte) (int, error) { + return r.writeWithScratch(buf, &r.writeIovs) +} + +// WriteReject emits a packet using a dedicated iovec scratch (rejectIovs) +// distinct from the one used by the coalescer's Write path. This avoids a +// data race between the inside (listenIn) goroutine emitting reject or +// self-forward packets and the outside (listenOut) goroutine flushing TCP +// coalescer passthroughs on the same tunFile. +func (r *tunFile) WriteReject(buf []byte) (int, error) { + return r.writeWithScratch(buf, &r.rejectIovs) +} + +func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) { + if !r.vnetHdr { + for { + if n, err := unix.Write(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnWrite(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } + } + + if len(buf) == 0 { + return 0, nil + } + // Point the payload iovec at the caller's buffer. iovs[0] is pre-wired + // to validVnetHdr during tunFile construction so we don't rebuild it here. + iovs[1].Base = &buf[0] + iovs[1].SetLen(len(buf)) + iovPtr := uintptr(unsafe.Pointer(&iovs[0])) + // The TUN fd is non-blocking (set in newTunFd / newFriend), so writev + // either completes promptly or returns EAGAIN — it cannot park the + // goroutine inside the kernel. That lets us use syscall.RawSyscall and + // skip the runtime.entersyscall / exitsyscall bookkeeping on every + // packet; we only pay that cost when we fall through to blockOnWrite. + for { + n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2) + if errno == 0 { + runtime.KeepAlive(buf) + if int(n) < virtioNetHdrLen { + return 0, io.ErrShortWrite + } + return int(n) - virtioNetHdrLen, nil + } + if errno == unix.EAGAIN { + runtime.KeepAlive(buf) + if err := r.blockOnWrite(); err != nil { return 0, err } continue - } else if err == unix.EINTR { - continue - } else if err == unix.EBADF { - return 0, os.ErrClosed - } else { - return 0, err } + if errno == unix.EINTR { + continue + } + runtime.KeepAlive(buf) + return 0, errno + } +} + +// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and +// can accept WriteGSO. When false, callers should fall back to per-segment +// Write calls. +func (r *tunFile) GSOSupported() bool { return r.vnetHdr } + +// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the +// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum, +// and TCP pseudo-header partial set by the caller). pays are payload +// fragments whose concatenation forms the full coalesced payload; each +// slice is read-only and must stay valid until return. gsoSize is the MSS; +// every segment except possibly the last is exactly gsoSize bytes. +// csumStart is the byte offset where the TCP header begins within hdr. +func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error { + if !r.vnetHdr { + return fmt.Errorf("WriteGSO called on tun without IFF_VNET_HDR") + } + if len(hdr) == 0 || len(pays) == 0 { + return nil + } + + // Build the virtio_net_hdr. When pays total to <= gsoSize the kernel + // would produce a single segment; keep NEEDS_CSUM semantics but skip + // the GSO type so the kernel doesn't spuriously mark this as TSO. + vhdr := virtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + HdrLen: uint16(len(hdr)), + GSOSize: gsoSize, + CsumStart: csumStart, + CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header + } + var totalPay int + for _, p := range pays { + totalPay += len(p) + } + if totalPay > int(gsoSize) { + if isV6 { + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + } else { + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + } + } else { + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE + vhdr.GSOSize = 0 + } + vhdr.encode(r.gsoHdrBuf[:]) + + // Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is + // wired to gsoHdrBuf at construction and never changes. + need := 2 + len(pays) + if cap(r.gsoIovs) < need { + grown := make([]unix.Iovec, need) + grown[0] = r.gsoIovs[0] + r.gsoIovs = grown + } else { + r.gsoIovs = r.gsoIovs[:need] + } + r.gsoIovs[1].Base = &hdr[0] + r.gsoIovs[1].SetLen(len(hdr)) + for i, p := range pays { + r.gsoIovs[2+i].Base = &p[0] + r.gsoIovs[2+i].SetLen(len(p)) + } + + iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0])) + iovCnt := uintptr(len(r.gsoIovs)) + for { + n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt) + if errno == 0 { + runtime.KeepAlive(hdr) + runtime.KeepAlive(pays) + if int(n) < virtioNetHdrLen { + return io.ErrShortWrite + } + return nil + } + if errno == unix.EAGAIN { + runtime.KeepAlive(hdr) + runtime.KeepAlive(pays) + if err := r.blockOnWrite(); err != nil { + return err + } + continue + } + if errno == unix.EINTR { + continue + } + runtime.KeepAlive(hdr) + runtime.KeepAlive(pays) + return errno } } @@ -239,7 +522,9 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) + // We don't know what flags the caller opened this fd with and can't turn + // on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds. + t, err := newTunGeneric(c, l, deviceFd, false, vpnNetworks) if err != nil { return nil, err } @@ -249,46 +534,83 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +// openTunDev opens /dev/net/tun, creating the device node first if it's +// missing (docker containers occasionally omit it). +func openTunDev() (int, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - // If /dev/net/tun doesn't exist, try to create it (will happen in docker) - if os.IsNotExist(err) { - err = os.MkdirAll("/dev/net", 0755) - if err != nil { - return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) - } - err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) - if err != nil { - return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) - } - - fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) - } - } else { - return nil, err - } + if err == nil { + return fd, nil } + if !os.IsNotExist(err) { + return -1, err + } + if err = os.MkdirAll("/dev/net", 0755); err != nil { + return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil { + return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err) + } + fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) + if err != nil { + return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) + } + return fd, nil +} +// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen +// device name on success. +func tunSetIff(fd int, name string, flags uint16) (string, error) { var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + req.Flags = flags + copy(req.Name[:], name) + if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + return "", err + } + return strings.Trim(string(req.Name[:]), "\x00"), nil +} + +// tsoOffloadFlags are the TUN_F_* bits we ask the kernel to enable when a +// TSO-capable TUN is available. CSUM is required as a prerequisite for TSO. +const tsoOffloadFlags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { + baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI) if multiqueue { - req.Flags |= unix.IFF_MULTI_QUEUE + baseFlags |= unix.IFF_MULTI_QUEUE } nameStr := c.GetString("tun.dev", "") - copy(req.Name[:], nameStr) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + + // First try to open with IFF_VNET_HDR + TUNSETOFFLOAD so we can receive + // TSO superpackets. If either step fails (older kernel, unprivileged + // container, etc.) we close and fall back to a plain TUN. + fd, err := openTunDev() + if err != nil { + return nil, err + } + vnetHdr := true + name, err := tunSetIff(fd, nameStr, baseFlags|unix.IFF_VNET_HDR|unix.IFF_NAPI) + if err != nil { _ = unix.Close(fd) - return nil, &NameError{ - Name: nameStr, - Underlying: err, + vnetHdr = false + } else if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil { + l.WithError(err).Warn("Failed to enable TUN offload (TSO); proceeding without virtio headers") + _ = unix.Close(fd) + vnetHdr = false + } + + if !vnetHdr { + fd, err = openTunDev() + if err != nil { + return nil, err + } + name, err = tunSetIff(fd, nameStr, baseFlags) + if err != nil { + _ = unix.Close(fd) + return nil, &NameError{Name: nameStr, Underlying: err} } } - name := strings.Trim(string(req.Name[:]), "\x00") - t, err := newTunGeneric(c, l, fd, vpnNetworks) + t, err := newTunGeneric(c, l, fd, vnetHdr, vpnNetworks) if err != nil { return nil, err } @@ -299,8 +621,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. -func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { - tfd, err := newTunFd(fd) +func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetworks []netip.Prefix) (*tun, error) { + tfd, err := newTunFd(fd, vnetHdr) if err != nil { _ = unix.Close(fd) return nil, err @@ -410,7 +732,7 @@ func (t *tun) SupportsMultiqueue() bool { return true } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (Queue, error) { t.closeLock.Lock() defer t.closeLock.Unlock() @@ -419,14 +741,22 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } - var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) - copy(req.Name[:], t.Device) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + if t.vnetHdr { + flags |= unix.IFF_VNET_HDR | unix.IFF_NAPI + } + if _, err = tunSetIff(fd, t.Device, flags); err != nil { _ = unix.Close(fd) return nil, err } + if t.vnetHdr { + if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err) + } + } + out, err := t.tunFile.newFriend(fd) if err != nil { _ = unix.Close(fd) diff --git a/overlay/tun_linux_offload.go b/overlay/tun_linux_offload.go new file mode 100644 index 00000000..2d6e9a58 --- /dev/null +++ b/overlay/tun_linux_offload.go @@ -0,0 +1,331 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package overlay + +import ( + "encoding/binary" + "fmt" + + "golang.org/x/sys/unix" +) + +// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on +// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set). +const virtioNetHdrLen = 10 + +// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A +// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header +// prefix plus the virtio header. +const tunReadBufSize = 65535 + +// Space for segmented output. Worst case is many small segments, each paying +// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling. +const tunSegBufSize = 131072 + +// tunSegBufCap is the total size we allocate for the per-reader segment +// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus +// the same again as drain headroom so a Read wake can accumulate +// additional packets after an initial big read without overflowing. +const tunSegBufCap = tunSegBufSize * 2 + +// tunDrainCap caps how many packets a single Read will accumulate via +// the post-wake drain loop. Sized to soak up a burst of small ACKs while +// bounding how much work a single caller holds before handing off. +const tunDrainCap = 64 + +type virtioNetHdr struct { + Flags uint8 + GSOType uint8 + HdrLen uint16 + GSOSize uint16 + CsumStart uint16 + CsumOffset uint16 +} + +// decode reads a virtio_net_hdr in host byte order (TUN default; we never +// call TUNSETVNETLE so the kernel matches our endianness). +func (h *virtioNetHdr) decode(b []byte) { + h.Flags = b[0] + h.GSOType = b[1] + h.HdrLen = binary.NativeEndian.Uint16(b[2:4]) + h.GSOSize = binary.NativeEndian.Uint16(b[4:6]) + h.CsumStart = binary.NativeEndian.Uint16(b[6:8]) + h.CsumOffset = binary.NativeEndian.Uint16(b[8:10]) +} + +// encode is the inverse of decode: writes the virtio_net_hdr fields into b +// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket +// on egress. +func (h *virtioNetHdr) encode(b []byte) { + b[0] = h.Flags + b[1] = h.GSOType + binary.NativeEndian.PutUint16(b[2:4], h.HdrLen) + binary.NativeEndian.PutUint16(b[4:6], h.GSOSize) + binary.NativeEndian.PutUint16(b[6:8], h.CsumStart) + binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset) +} + +// segmentInto splits a TUN-side packet described by hdr into one or more +// IP packets, each appended to *out as a slice of scratch. scratch must be +// sized to hold every segment (including replicated headers). +func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error { + // When RSC_INFO is set the csum_start/csum_offset fields are repurposed to + // carry coalescing info rather than checksum offsets. A TUN writing via + // IFF_VNET_HDR should never emit this, but if it did we would silently + // miscompute the segment checksums — refuse the packet instead. + if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 { + return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads") + } + + switch hdr.GSOType { + case unix.VIRTIO_NET_HDR_GSO_NONE: + if len(pkt) > len(scratch) { + return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch)) + } + copy(scratch, pkt) + seg := scratch[:len(pkt)] + if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + if err := finishChecksum(seg, hdr); err != nil { + return err + } + } + *out = append(*out, seg) + return nil + + case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6: + return segmentTCP(pkt, hdr, out, scratch) + + default: + return fmt.Errorf("unsupported virtio gso type: %d", hdr.GSOType) + } +} + +// finishChecksum computes the L4 checksum for a non-GSO packet that the kernel +// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit +// checksum field; we zero it, fold a full sum (the field was pre-loaded with +// the pseudo-header partial sum by the kernel), and store the result. +func finishChecksum(seg []byte, hdr virtioNetHdr) error { + cs := int(hdr.CsumStart) + co := int(hdr.CsumOffset) + if cs+co+2 > len(seg) { + return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg)) + } + // The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the + // L4 region starting at cs, folding the prior partial in as the seed. + partial := uint32(binary.BigEndian.Uint16(seg[cs+co : cs+co+2])) + seg[cs+co] = 0 + seg[cs+co+1] = 0 + sum := checksumBytes(seg[cs:], partial) + binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], checksumFold(sum)) + return nil +} + +// segmentTCP software-segments a TSO superpacket into one IP packet per MSS +// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6. +// +// Hot-path shape: the per-segment loop only sums the payload chunk. The TCP +// header, the IPv4 header, and the pseudo-header src/dst/proto contributions +// are each summed once up front — every segment reuses those three pre-folded +// uint32 values and combines them with small per-segment deltas (seq, flags, +// tcpLen, ip_id, total_len) that are cheap to fold in. +func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error { + if hdr.GSOSize == 0 { + return fmt.Errorf("gso_size is zero") + } + if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 { + return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt)) + } + if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen { + return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen) + } + + isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4 + headerLen := int(hdr.HdrLen) + csumStart := int(hdr.CsumStart) + + if isV4 && csumStart < 20 { + return fmt.Errorf("csum_start %d too small for IPv4", csumStart) + } + if !isV4 && csumStart < 40 { + return fmt.Errorf("csum_start %d too small for IPv6", csumStart) + } + tcpHdrLen := headerLen - csumStart + if tcpHdrLen < 20 { + return fmt.Errorf("tcp header region too small: %d", tcpHdrLen) + } + + payload := pkt[headerLen:] + payLen := len(payload) + gso := int(hdr.GSOSize) + numSeg := (payLen + gso - 1) / gso + if numSeg == 0 { + numSeg = 1 + } + + need := numSeg*headerLen + payLen + if need > len(scratch) { + return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch)) + } + + origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8]) + origFlags := pkt[csumStart+13] + const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08) + + // Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP + // header is 60 bytes; copy onto the stack, zero the per-segment-varying + // fields, sum once. + var tmp [60]byte + copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen]) + tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq + tmp[13] = 0 // flags + tmp[16], tmp[17] = 0, 0 // csum + baseTcpHdrSum := checksumBytes(tmp[:tcpHdrLen], 0) + + // Pseudo-header src+dst+proto contribution (tcpLen varies per segment). + var baseProtoSum uint32 + if isV4 { + baseProtoSum = checksumBytes(pkt[12:16], 0) + baseProtoSum = checksumBytes(pkt[16:20], baseProtoSum) + } else { + baseProtoSum = checksumBytes(pkt[8:24], 0) + baseProtoSum = checksumBytes(pkt[24:40], baseProtoSum) + } + baseProtoSum += uint32(unix.IPPROTO_TCP) + + // Precompute IPv4 header sum with total_len/id/csum zeroed. + var origIPID uint16 + var ihl int + var baseIPHdrSum uint32 + if isV4 { + origIPID = binary.BigEndian.Uint16(pkt[4:6]) + ihl = int(pkt[0]&0x0f) * 4 + if ihl < 20 || ihl > csumStart { + return fmt.Errorf("bad IPv4 IHL: %d", ihl) + } + var ipTmp [60]byte + copy(ipTmp[:ihl], pkt[:ihl]) + ipTmp[2], ipTmp[3] = 0, 0 // total_len + ipTmp[4], ipTmp[5] = 0, 0 // id + ipTmp[10], ipTmp[11] = 0, 0 // checksum + baseIPHdrSum = checksumBytes(ipTmp[:ihl], 0) + } + + off := 0 + for i := 0; i < numSeg; i++ { + segStart := i * gso + segEnd := segStart + gso + if segEnd > payLen { + segEnd = payLen + } + segPayLen := segEnd - segStart + + copy(scratch[off:], pkt[:headerLen]) + copy(scratch[off+headerLen:], payload[segStart:segEnd]) + seg := scratch[off : off+headerLen+segPayLen] + off += headerLen + segPayLen + + segSeq := origSeq + uint32(segStart) + segFlags := origFlags + if i != numSeg-1 { + segFlags = origFlags &^ tcpFinPsh + } + totalLen := headerLen + segPayLen + + // Patch IP header and write the v4 header checksum from the precomputed base. + if isV4 { + segID := origIPID + uint16(i) + binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen)) + binary.BigEndian.PutUint16(seg[4:6], segID) + ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID) + binary.BigEndian.PutUint16(seg[10:12], checksumFold(ipSum)) + } else { + // IPv6 payload length excludes the 40-byte fixed header but + // includes any extension headers between [40:csumStart]. + binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen)) + } + + // Patch TCP header. + binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq) + seg[csumStart+13] = segFlags + // (csum is written below; its prior contents in `seg` don't affect the + // computation since we never sum over the segment's own header.) + + tcpLen := tcpHdrLen + segPayLen + paySum := checksumBytes(payload[segStart:segEnd], 0) + + // Combine pre-folded uint32s into a wider accumulator, then fold. Using + // uint64 guards against overflow when segSeq's high bits set. + wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum) + wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen) + wide = (wide & 0xffffffff) + (wide >> 32) + wide = (wide & 0xffffffff) + (wide >> 32) + binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(uint32(wide))) + + *out = append(*out, seg) + } + + return nil +} + +// checksumBytes returns the Internet-checksum partial sum of b, seeded with +// initial. Result is a 32-bit accumulator; the caller folds to 16. +// +// Each 4-byte load is added directly into a 64-bit accumulator. Two parallel +// accumulators break the serial dependency through `sum` and let the CPU +// overlap independent adds. The final fold from 64 → 32 → 16 handles the +// carries that accumulated across the 32-bit lane boundary. +func checksumBytes(b []byte, initial uint32) uint32 { + s0 := uint64(initial) + var s1 uint64 + for len(b) >= 32 { + s0 += uint64(binary.BigEndian.Uint32(b[0:4])) + s1 += uint64(binary.BigEndian.Uint32(b[4:8])) + s0 += uint64(binary.BigEndian.Uint32(b[8:12])) + s1 += uint64(binary.BigEndian.Uint32(b[12:16])) + s0 += uint64(binary.BigEndian.Uint32(b[16:20])) + s1 += uint64(binary.BigEndian.Uint32(b[20:24])) + s0 += uint64(binary.BigEndian.Uint32(b[24:28])) + s1 += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + sum := s0 + s1 + for len(b) >= 4 { + sum += uint64(binary.BigEndian.Uint32(b[:4])) + b = b[4:] + } + if len(b) >= 2 { + sum += uint64(binary.BigEndian.Uint16(b[:2])) + b = b[2:] + } + if len(b) == 1 { + sum += uint64(b[0]) << 8 + } + sum = (sum & 0xffffffff) + (sum >> 32) + sum = (sum & 0xffffffff) + (sum >> 32) + return uint32(sum) +} + +func checksumFold(sum uint32) uint16 { + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} + +func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint32 { + sum := checksumBytes(src, 0) + sum = checksumBytes(dst, sum) + sum += uint32(proto) + sum += uint32(tcpLen) + return sum +} + +func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint32 { + sum := checksumBytes(src, 0) + sum = checksumBytes(dst, sum) + sum += uint32(tcpLen >> 16) + sum += uint32(tcpLen & 0xffff) + sum += uint32(proto) + return sum +} diff --git a/overlay/tun_linux_offload_test.go b/overlay/tun_linux_offload_test.go new file mode 100644 index 00000000..650165bc --- /dev/null +++ b/overlay/tun_linux_offload_test.go @@ -0,0 +1,333 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package overlay + +import ( + "encoding/binary" + "os" + "testing" + + "golang.org/x/sys/unix" +) + +// verifyChecksum confirms that the one's-complement sum across `b`, optionally +// seeded with a pseudo-header sum, folds to all-ones (valid). +func verifyChecksum(b []byte, pseudo uint32) bool { + sum := checksumBytes(b, pseudo) + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return uint16(sum) == 0xffff +} + +// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of +// `payLen` bytes split at `mss`. +func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) { + t.Helper() + const ipLen = 20 + const tcpLen = 20 + pkt := make([]byte, ipLen+tcpLen+payLen) + + // IPv4 header + pkt[0] = 0x45 // version 4, IHL 5 + // total length is meaningless for TSO but set it anyway + binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen)) + binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID + pkt[8] = 64 // TTL + pkt[9] = unix.IPPROTO_TCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) // src + copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst + + // TCP header + binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport + binary.BigEndian.PutUint16(pkt[22:24], 80) // dport + binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq + binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack + pkt[32] = 0x50 // data offset 5 words + pkt[33] = 0x18 // ACK | PSH + binary.BigEndian.PutUint16(pkt[34:36], 65535) // window + + // payload + for i := 0; i < payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i & 0xff) + } + + return pkt, virtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } +} + +func TestSegmentTCPv4(t *testing.T) { + const mss = 100 + const numSeg = 3 + pkt, hdr := buildTSOv4(t, mss*numSeg, mss) + + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentTCP: %v", err) + } + if len(out) != numSeg { + t.Fatalf("expected %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != 40+mss { + t.Errorf("seg %d: unexpected len %d", i, len(seg)) + } + totalLen := binary.BigEndian.Uint16(seg[2:4]) + if totalLen != uint16(40+mss) { + t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss) + } + id := binary.BigEndian.Uint16(seg[4:6]) + if id != 0x4242+uint16(i) { + t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i)) + } + seq := binary.BigEndian.Uint32(seg[24:28]) + wantSeq := uint32(10000 + i*mss) + if seq != wantSeq { + t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq) + } + flags := seg[33] + wantFlags := byte(0x10) // ACK only, PSH cleared + if i == numSeg-1 { + wantFlags = 0x18 // ACK | PSH preserved on last + } + if flags != wantFlags { + t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags) + } + // IPv4 header checksum must verify against itself. + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + // TCP checksum must verify against the pseudo-header. + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentTCPv4OddTail(t *testing.T) { + // Payload of 250 bytes with MSS 100 → segments of 100, 100, 50. + pkt, hdr := buildTSOv4(t, 250, 100) + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentTCP: %v", err) + } + if len(out) != 3 { + t.Fatalf("want 3 segments, got %d", len(out)) + } + wantPayLens := []int{100, 100, 50} + for i, seg := range out { + if len(seg)-40 != wantPayLens[i] { + t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i]) + } + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i]) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentTCPv6(t *testing.T) { + const ipLen = 40 + const tcpLen = 20 + const mss = 120 + const numSeg = 2 + payLen := mss * numSeg + pkt := make([]byte, ipLen+tcpLen+payLen) + + // IPv6 header + pkt[0] = 0x60 // version 6 + binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen)) + pkt[6] = unix.IPPROTO_TCP + pkt[7] = 64 + // src/dst fe80::1 / fe80::2 + pkt[8] = 0xfe + pkt[9] = 0x80 + pkt[23] = 1 + pkt[24] = 0xfe + pkt[25] = 0x80 + pkt[39] = 2 + + // TCP header + binary.BigEndian.PutUint16(pkt[40:42], 12345) + binary.BigEndian.PutUint16(pkt[42:44], 80) + binary.BigEndian.PutUint32(pkt[44:48], 7) + binary.BigEndian.PutUint32(pkt[48:52], 99) + pkt[52] = 0x50 + pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too + binary.BigEndian.PutUint16(pkt[54:56], 65535) + + for i := 0; i < payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i) + } + + hdr := virtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } + + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentTCP: %v", err) + } + if len(out) != numSeg { + t.Fatalf("want %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != ipLen+tcpLen+mss { + t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss) + } + pl := binary.BigEndian.Uint16(seg[4:6]) + if pl != uint16(tcpLen+mss) { + t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss) + } + seq := binary.BigEndian.Uint32(seg[44:48]) + if seq != uint32(7+i*mss) { + t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss) + } + flags := seg[53] + // Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be + // cleared on all but the last; ACK(0x10) always preserved. + wantFlags := byte(0x10) + if i == numSeg-1 { + wantFlags = 0x19 + } + if flags != wantFlags { + t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags) + } + psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss) + if !verifyChecksum(seg[ipLen:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentGSONonePassesThrough(t *testing.T) { + pkt, hdr := buildTSOv4(t, 100, 100) + hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE + hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched + + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentInto(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentInto: %v", err) + } + if len(out) != 1 { + t.Fatalf("want 1 segment, got %d", len(out)) + } + if len(out[0]) != len(pkt) { + t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt)) + } +} + +func TestSegmentRejectsUDP(t *testing.T) { + hdr := virtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP} + var out [][]byte + if err := segmentInto(nil, hdr, &out, nil); err == nil { + t.Fatalf("expected rejection for UDP GSO") + } +} + +func BenchmarkSegmentTCPv4(b *testing.B) { + sizes := []struct { + name string + payLen int + mss int + }{ + {"64KiB_MSS1460", 65000, 1460}, + {"16KiB_MSS1460", 16384, 1460}, + {"4KiB_MSS1460", 4096, 1460}, + } + for _, sz := range sizes { + b.Run(sz.name, func(b *testing.B) { + const ipLen = 20 + const tcpLen = 20 + pkt := make([]byte, ipLen+tcpLen+sz.payLen) + pkt[0] = 0x45 + binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen)) + binary.BigEndian.PutUint16(pkt[4:6], 0x4242) + pkt[8] = 64 + pkt[9] = unix.IPPROTO_TCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + binary.BigEndian.PutUint16(pkt[20:22], 12345) + binary.BigEndian.PutUint16(pkt[22:24], 80) + binary.BigEndian.PutUint32(pkt[24:28], 10000) + binary.BigEndian.PutUint32(pkt[28:32], 20000) + pkt[32] = 0x50 + pkt[33] = 0x18 + binary.BigEndian.PutUint16(pkt[34:36], 65535) + for i := 0; i < sz.payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i) + } + hdr := virtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(sz.mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } + + scratch := make([]byte, tunSegBufSize) + out := make([][]byte, 0, 64) + + b.SetBytes(int64(len(pkt))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + out = out[:0] + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + b.Fatal(err) + } + } + }) + } +} + +// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is +// allocation-free. We write to /dev/null so every call succeeds synchronously. +func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) { + fd, err := unix.Open("/dev/null", os.O_WRONLY, 0) + if err != nil { + t.Fatalf("open /dev/null: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fd) }) + + tf := &tunFile{fd: fd, vnetHdr: true} + tf.writeIovs[0].Base = &validVnetHdr[0] + tf.writeIovs[0].SetLen(virtioNetHdrLen) + + payload := make([]byte, 1400) + // Warm up (first call may trigger one-time internal allocations elsewhere). + if _, err := tf.Write(payload); err != nil { + t.Fatalf("Write: %v", err) + } + + allocs := testing.AllocsPerRun(1000, func() { + if _, err := tf.Write(payload); err != nil { + t.Fatalf("Write: %v", err) + } + }) + if allocs != 0 { + t.Fatalf("Write allocated %.1f times per call, want 0", allocs) + } +} diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..995a9a9f 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "net/netip" "os" "regexp" @@ -66,6 +65,25 @@ type tun struct { l *logrus.Logger f *os.File fd int + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteReject(p []byte) (int, error) { + return t.Write(p) } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -141,7 +159,7 @@ func (t *tun) Close() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { rc, err := t.f.SyscallConn() if err != nil { return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err) @@ -394,7 +412,7 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..aab29bb5 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "net/netip" "os" "regexp" @@ -59,6 +58,25 @@ type tun struct { fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteReject(p []byte) (int, error) { + return t.Write(p) } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -124,7 +142,7 @@ func (t *tun) Close() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { buf := make([]byte, len(to)+4) n, err := t.f.Read(buf) @@ -314,7 +332,7 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..684d1ce1 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -26,6 +26,17 @@ type TestTun struct { closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula + + batchRet [1][]byte +} + +func (t *TestTun) Read() ([][]byte, error) { + p, ok := <-t.rxPackets + if !ok { + return nil, os.ErrClosed + } + t.batchRet[0] = p + return t.batchRet[:], nil } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { @@ -115,6 +126,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) { return len(b), nil } +func (t *TestTun) WriteReject(b []byte) (int, error) { + return t.Write(b) +} + func (t *TestTun) Close() error { if t.closed.CompareAndSwap(false, true) { close(t.rxPackets) @@ -123,19 +138,10 @@ func (t *TestTun) Close() error { return nil } -func (t *TestTun) Read(b []byte) (int, error) { - p, ok := <-t.rxPackets - if !ok { - return 0, os.ErrClosed - } - copy(b, p) - return len(p), nil -} - func (t *TestTun) SupportsMultiqueue() bool { return false } -func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *TestTun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 223eabee..b02f33d5 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -6,7 +6,6 @@ package overlay import ( "crypto" "fmt" - "io" "net/netip" "os" "path/filepath" @@ -36,6 +35,25 @@ type winTun struct { l *logrus.Logger tun *wintun.NativeTun + + readBuf []byte + batchRet [1][]byte +} + +func (t *winTun) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.tun.Read(t.readBuf, 0) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *winTun) WriteReject(p []byte) (int, error) { + return t.Write(p) } func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { @@ -229,10 +247,6 @@ func (t *winTun) Name() string { return t.Device } -func (t *winTun) Read(b []byte) (int, error) { - return t.tun.Read(b, 0) -} - func (t *winTun) Write(b []byte) (int, error) { return t.tun.Write(b, 0) } @@ -241,7 +255,7 @@ func (t *winTun) SupportsMultiqueue() bool { return false } -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *winTun) NewMultiQueueReader() (Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") } diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..77c2d025 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -34,6 +34,21 @@ type UserDevice struct { inboundReader *io.PipeReader inboundWriter *io.PipeWriter + + readBuf []byte + batchRet [1][]byte +} + +func (d *UserDevice) Read() ([][]byte, error) { + if d.readBuf == nil { + d.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := d.outboundReader.Read(d.readBuf) + if err != nil { + return nil, err + } + d.batchRet[0] = d.readBuf[:n] + return d.batchRet[:], nil } func (d *UserDevice) Activate() error { @@ -50,7 +65,7 @@ func (d *UserDevice) SupportsMultiqueue() bool { return true } -func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (d *UserDevice) NewMultiQueueReader() (Queue, error) { return d, nil } @@ -58,12 +73,12 @@ func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { return d.inboundReader, d.outboundWriter } -func (d *UserDevice) Read(p []byte) (n int, err error) { - return d.outboundReader.Read(p) -} func (d *UserDevice) Write(p []byte) (n int, err error) { return d.inboundWriter.Write(p) } +func (d *UserDevice) WriteReject(p []byte) (n int, err error) { + return d.Write(p) +} func (d *UserDevice) Close() error { d.inboundWriter.Close() d.outboundWriter.Close() diff --git a/tcp_coalesce.go b/tcp_coalesce.go new file mode 100644 index 00000000..26ba1181 --- /dev/null +++ b/tcp_coalesce.go @@ -0,0 +1,484 @@ +package nebula + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/slackhq/nebula/overlay" +) + +// ipProtoTCP is the IANA protocol number for TCP. Hardcoded instead of +// reaching for golang.org/x/sys/unix — that package doesn't define the +// constant on Windows, which would break cross-compiles even though this +// file runs unchanged on every platform. +const ipProtoTCP = 6 + +// tcpCoalesceBufSize caps total bytes per superpacket. Mirrors the kernel's +// sk_gso_max_size of ~64KiB; anything beyond this would be rejected anyway. +const tcpCoalesceBufSize = 65535 + +// tcpCoalesceMaxSegs caps how many segments we'll coalesce into a single +// superpacket. Keeping this well below the kernel's TSO ceiling bounds +// latency. +const tcpCoalesceMaxSegs = 64 + +// tcpCoalesceHdrCap is the scratch space we copy a seed's IP+TCP header +// into. IPv6 (40) + TCP with full options (60) = 100 bytes. +const tcpCoalesceHdrCap = 100 + +// initialSlots is the starting capacity of the slot pool. One flow per +// packet is the worst case so this matches a typical UDP recvmmsg batch. +const initialSlots = 64 + +// flowKey identifies a TCP flow by {src, dst, sport, dport, family}. +// Comparable, so linear scans over the slot list stay tight. +type flowKey struct { + src, dst [16]byte + sport, dport uint16 + isV6 bool +} + +// coalesceSlot is one entry in the coalescer's ordered event queue. When +// passthrough is true the slot holds a single borrowed packet that must be +// emitted verbatim (non-TCP, non-admissible TCP, or oversize seed). When +// passthrough is false the slot is an in-progress coalesced superpacket: +// hdrBuf is a mutable copy of the seed's IP+TCP header (we patch total +// length and pseudo-header partial at flush), and payIovs are *borrowed* +// slices from the caller's plaintext buffers — no payload is ever copied. +// The caller (listenOut) must keep those buffers alive until Flush. +type coalesceSlot struct { + passthrough bool + rawPkt []byte // borrowed when passthrough + + fk flowKey + hdrBuf [tcpCoalesceHdrCap]byte + hdrLen int + ipHdrLen int + isV6 bool + gsoSize int + numSeg int + totalPay int + nextSeq uint32 + // psh closes the chain: set when the last-accepted segment had PSH or + // was sub-gsoSize. No further appends after that. + psh bool + payIovs [][]byte +} + +// tcpCoalescer accumulates adjacent in-flow TCP data segments across +// multiple concurrent flows and emits each flow's run as a single TSO +// superpacket via overlay.GSOWriter. All output — coalesced or not — is +// deferred until Flush so arrival order is preserved on the wire. Owns +// no locks; one coalescer per TUN write queue. +type tcpCoalescer struct { + plainW io.Writer + gsoW overlay.GSOWriter // nil when the queue doesn't support TSO + + // slots is the ordered event queue. Flush walks it once and emits each + // entry as either a WriteGSO (coalesced) or a plainW.Write (passthrough). + slots []*coalesceSlot + // openSlots maps a flow key to its most recent non-sealed slot, so new + // segments can extend an in-progress superpacket in O(1). Slots are + // removed from this map when they close (PSH or short-last-segment), + // when a non-admissible packet for that flow arrives, or in Flush. + openSlots map[flowKey]*coalesceSlot + pool []*coalesceSlot // free list for reuse +} + +func newTCPCoalescer(w io.Writer) *tcpCoalescer { + c := &tcpCoalescer{ + plainW: w, + slots: make([]*coalesceSlot, 0, initialSlots), + openSlots: make(map[flowKey]*coalesceSlot, initialSlots), + pool: make([]*coalesceSlot, 0, initialSlots), + } + if gw, ok := w.(overlay.GSOWriter); ok && gw.GSOSupported() { + c.gsoW = gw + } + return c +} + +// parsedTCP holds the fields extracted from a single parse so later steps +// (admission, slot lookup, canAppend) don't re-walk the header. +type parsedTCP struct { + fk flowKey + ipHdrLen int + tcpHdrLen int + hdrLen int + payLen int + seq uint32 + flags byte +} + +// parseTCPBase extracts the flow key and IP/TCP offsets for any TCP packet, +// regardless of whether it's admissible for coalescing. Returns ok=false +// for non-TCP or malformed input. Accepts IPv4 (no options, no fragmentation) +// and IPv6 (no extension headers). +func parseTCPBase(pkt []byte) (parsedTCP, bool) { + var p parsedTCP + if len(pkt) < 20 { + return p, false + } + v := pkt[0] >> 4 + switch v { + case 4: + ihl := int(pkt[0]&0x0f) * 4 + if ihl != 20 { + return p, false + } + if pkt[9] != ipProtoTCP { + return p, false + } + // Reject actual fragmentation (MF or non-zero frag offset). + if binary.BigEndian.Uint16(pkt[6:8])&0x3fff != 0 { + return p, false + } + totalLen := int(binary.BigEndian.Uint16(pkt[2:4])) + if totalLen > len(pkt) || totalLen < ihl { + return p, false + } + p.ipHdrLen = 20 + p.fk.isV6 = false + copy(p.fk.src[:4], pkt[12:16]) + copy(p.fk.dst[:4], pkt[16:20]) + pkt = pkt[:totalLen] + case 6: + if len(pkt) < 40 { + return p, false + } + if pkt[6] != ipProtoTCP { + return p, false + } + payloadLen := int(binary.BigEndian.Uint16(pkt[4:6])) + if 40+payloadLen > len(pkt) { + return p, false + } + p.ipHdrLen = 40 + p.fk.isV6 = true + copy(p.fk.src[:], pkt[8:24]) + copy(p.fk.dst[:], pkt[24:40]) + pkt = pkt[:40+payloadLen] + default: + return p, false + } + + if len(pkt) < p.ipHdrLen+20 { + return p, false + } + tcpOff := int(pkt[p.ipHdrLen+12]>>4) * 4 + if tcpOff < 20 || tcpOff > 60 { + return p, false + } + if len(pkt) < p.ipHdrLen+tcpOff { + return p, false + } + p.tcpHdrLen = tcpOff + p.hdrLen = p.ipHdrLen + tcpOff + p.payLen = len(pkt) - p.hdrLen + p.seq = binary.BigEndian.Uint32(pkt[p.ipHdrLen+4 : p.ipHdrLen+8]) + p.flags = pkt[p.ipHdrLen+13] + p.fk.sport = binary.BigEndian.Uint16(pkt[p.ipHdrLen : p.ipHdrLen+2]) + p.fk.dport = binary.BigEndian.Uint16(pkt[p.ipHdrLen+2 : p.ipHdrLen+4]) + return p, true +} + +// coalesceable reports whether a parsed TCP segment is eligible for +// coalescing. Accepts only ACK or ACK|PSH with a non-empty payload. +func (p parsedTCP) coalesceable() bool { + const ack = 0x10 + const psh = 0x08 + if p.flags&^(ack|psh) != 0 || p.flags&ack == 0 { + return false + } + return p.payLen > 0 +} + +// Add borrows pkt. The caller must keep pkt valid until the next Flush, +// whether or not the packet was coalesced — passthrough (non-admissible) +// packets are queued and written at Flush time, not synchronously. +func (c *tcpCoalescer) Add(pkt []byte) error { + if c.gsoW == nil { + c.addPassthrough(pkt) + return nil + } + + info, ok := parseTCPBase(pkt) + if !ok { + // Non-TCP or malformed — can't possibly collide with an open flow. + c.addPassthrough(pkt) + return nil + } + if !info.coalesceable() { + // TCP but not admissible (SYN/FIN/RST/URG/CWR/ECE or zero-payload). + // Seal this flow's open slot so later in-flow packets don't extend + // it and accidentally reorder past this passthrough. + delete(c.openSlots, info.fk) + c.addPassthrough(pkt) + return nil + } + + if open := c.openSlots[info.fk]; open != nil { + if c.canAppend(open, pkt, info) { + c.appendPayload(open, pkt, info) + if open.psh { + delete(c.openSlots, info.fk) + } + return nil + } + // Can't extend — seal it and fall through to seed a fresh slot. + delete(c.openSlots, info.fk) + } + c.seed(pkt, info) + return nil +} + +// Flush emits every queued event in arrival order. Coalesced slots go out +// via WriteGSO; passthrough slots go out via plainW.Write. Returns the +// first error observed; keeps draining so one bad packet doesn't hold up +// the rest. After Flush returns, borrowed payload slices may be recycled. +func (c *tcpCoalescer) Flush() error { + var first error + for _, s := range c.slots { + var err error + if s.passthrough { + _, err = c.plainW.Write(s.rawPkt) + } else { + err = c.flushSlot(s) + } + if err != nil && first == nil { + first = err + } + c.release(s) + } + for i := range c.slots { + c.slots[i] = nil + } + c.slots = c.slots[:0] + for k := range c.openSlots { + delete(c.openSlots, k) + } + return first +} + +func (c *tcpCoalescer) addPassthrough(pkt []byte) { + s := c.take() + s.passthrough = true + s.rawPkt = pkt + c.slots = append(c.slots, s) +} + +func (c *tcpCoalescer) seed(pkt []byte, info parsedTCP) { + if info.hdrLen > tcpCoalesceHdrCap || info.hdrLen+info.payLen > tcpCoalesceBufSize { + // Pathological shape — can't fit our scratch, emit as-is. + c.addPassthrough(pkt) + return + } + s := c.take() + s.passthrough = false + s.rawPkt = nil + copy(s.hdrBuf[:], pkt[:info.hdrLen]) + s.hdrLen = info.hdrLen + s.ipHdrLen = info.ipHdrLen + s.isV6 = info.fk.isV6 + s.fk = info.fk + s.gsoSize = info.payLen + s.numSeg = 1 + s.totalPay = info.payLen + s.nextSeq = info.seq + uint32(info.payLen) + s.psh = info.flags&0x08 != 0 + s.payIovs = append(s.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen]) + c.slots = append(c.slots, s) + if !s.psh { + c.openSlots[info.fk] = s + } +} + +// canAppend reports whether info's packet extends the slot's seed: same +// header shape and stable contents, adjacent seq, not oversized, chain not +// closed. +func (c *tcpCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bool { + if s.psh { + return false + } + if info.hdrLen != s.hdrLen { + return false + } + if info.seq != s.nextSeq { + return false + } + if s.numSeg >= tcpCoalesceMaxSegs { + return false + } + if info.payLen > s.gsoSize { + return false + } + if s.hdrLen+s.totalPay+info.payLen > tcpCoalesceBufSize { + return false + } + if !headersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen) { + return false + } + return true +} + +func (c *tcpCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP) { + s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen]) + s.numSeg++ + s.totalPay += info.payLen + s.nextSeq = info.seq + uint32(info.payLen) + if info.payLen < s.gsoSize || info.flags&0x08 != 0 { + s.psh = true + } +} + +func (c *tcpCoalescer) take() *coalesceSlot { + if n := len(c.pool); n > 0 { + s := c.pool[n-1] + c.pool[n-1] = nil + c.pool = c.pool[:n-1] + return s + } + return &coalesceSlot{} +} + +func (c *tcpCoalescer) release(s *coalesceSlot) { + s.passthrough = false + s.rawPkt = nil + for i := range s.payIovs { + s.payIovs[i] = nil + } + s.payIovs = s.payIovs[:0] + s.numSeg = 0 + s.totalPay = 0 + s.psh = false + c.pool = append(c.pool, s) +} + +// flushSlot patches the header and calls WriteGSO. Does not remove the +// slot from c.slots. +func (c *tcpCoalescer) flushSlot(s *coalesceSlot) error { + total := s.hdrLen + s.totalPay + l4Len := total - s.ipHdrLen + hdr := s.hdrBuf[:s.hdrLen] + + if s.isV6 { + binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len)) + } else { + binary.BigEndian.PutUint16(hdr[2:4], uint16(total)) + hdr[10] = 0 + hdr[11] = 0 + binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen])) + } + + var psum uint32 + if s.isV6 { + psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoTCP, l4Len) + } else { + psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoTCP, l4Len) + } + tcsum := s.ipHdrLen + 16 + binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum)) + + return c.gsoW.WriteGSO(hdr, s.payIovs, uint16(s.gsoSize), s.isV6, uint16(s.ipHdrLen)) +} + +// headersMatch compares two IP+TCP header prefixes for byte-for-byte +// equality on every field that must be identical across coalesced +// segments. Size/IPID/IPCsum/seq/flags/tcpCsum are masked out. +func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool { + if len(a) != len(b) { + return false + } + if isV6 { + // IPv6: bytes [0:4] = version/TC/flow-label, [6:8] = next_hdr/hop, + // [8:40] = src+dst. Skip [4:6] payload length. + if !bytes.Equal(a[0:4], b[0:4]) { + return false + } + if !bytes.Equal(a[6:40], b[6:40]) { + return false + } + } else { + // IPv4: [0:2] version/IHL/TOS, [6:10] flags/fragoff/TTL/proto, + // [12:20] src+dst. Skip [2:4] total len, [4:6] id, [10:12] csum. + if !bytes.Equal(a[0:2], b[0:2]) { + return false + } + if !bytes.Equal(a[6:10], b[6:10]) { + return false + } + if !bytes.Equal(a[12:20], b[12:20]) { + return false + } + } + // TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window, + // [18:tcpHdrLen] options (incl. urgent). + tcp := ipHdrLen + if !bytes.Equal(a[tcp:tcp+4], b[tcp:tcp+4]) { + return false + } + if !bytes.Equal(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) { + return false + } + if !bytes.Equal(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) { + return false + } + if !bytes.Equal(a[tcp+18:], b[tcp+18:]) { + return false + } + return true +} + +// ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must +// already have its checksum field zeroed) and returns the folded/inverted +// 16-bit value to store. +func ipv4HdrChecksum(hdr []byte) uint16 { + var sum uint32 + for i := 0; i+1 < len(hdr); i += 2 { + sum += uint32(binary.BigEndian.Uint16(hdr[i : i+2])) + } + if len(hdr)%2 == 1 { + sum += uint32(hdr[len(hdr)-1]) << 8 + } + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} + +// pseudoSumIPv4 / pseudoSumIPv6 build the TCP pseudo-header partial sum +// expected by the virtio NEEDS_CSUM kernel path: the 32-bit accumulator +// before folding. +func pseudoSumIPv4(src, dst []byte, proto byte, l4Len int) uint32 { + var sum uint32 + sum += uint32(binary.BigEndian.Uint16(src[0:2])) + sum += uint32(binary.BigEndian.Uint16(src[2:4])) + sum += uint32(binary.BigEndian.Uint16(dst[0:2])) + sum += uint32(binary.BigEndian.Uint16(dst[2:4])) + sum += uint32(proto) + sum += uint32(l4Len) + return sum +} + +func pseudoSumIPv6(src, dst []byte, proto byte, l4Len int) uint32 { + var sum uint32 + for i := 0; i < 16; i += 2 { + sum += uint32(binary.BigEndian.Uint16(src[i : i+2])) + sum += uint32(binary.BigEndian.Uint16(dst[i : i+2])) + } + sum += uint32(l4Len >> 16) + sum += uint32(l4Len & 0xffff) + sum += uint32(proto) + return sum +} + +// foldOnceNoInvert folds the 32-bit accumulator to 16 bits and returns it +// unchanged (no one's complement). This is what virtio NEEDS_CSUM wants in +// the L4 checksum field — the kernel will add the payload sum and invert. +func foldOnceNoInvert(sum uint32) uint16 { + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return uint16(sum) +} diff --git a/tcp_coalesce_test.go b/tcp_coalesce_test.go new file mode 100644 index 00000000..9d7713fc --- /dev/null +++ b/tcp_coalesce_test.go @@ -0,0 +1,576 @@ +package nebula + +import ( + "encoding/binary" + "testing" +) + +// fakeTunWriter records plain Writes and WriteGSO calls without touching a +// real TUN fd. WriteGSO preserves the split between hdr and borrowed pays +// so tests can inspect each independently. +type fakeTunWriter struct { + gsoEnabled bool + writes [][]byte + gsoWrites []fakeGSOWrite +} + +type fakeGSOWrite struct { + hdr []byte + pays [][]byte + gsoSize uint16 + isV6 bool + csumStart uint16 +} + +// total returns hdrLen + sum of pay lens. +func (g fakeGSOWrite) total() int { + n := len(g.hdr) + for _, p := range g.pays { + n += len(p) + } + return n +} + +// payLen sums the pays. +func (g fakeGSOWrite) payLen() int { + var n int + for _, p := range g.pays { + n += len(p) + } + return n +} + +func (w *fakeTunWriter) Write(p []byte) (int, error) { + buf := make([]byte, len(p)) + copy(buf, p) + w.writes = append(w.writes, buf) + return len(p), nil +} + +func (w *fakeTunWriter) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error { + hcopy := make([]byte, len(hdr)) + copy(hcopy, hdr) + paysCopy := make([][]byte, len(pays)) + for i, p := range pays { + pc := make([]byte, len(p)) + copy(pc, p) + paysCopy[i] = pc + } + w.gsoWrites = append(w.gsoWrites, fakeGSOWrite{ + hdr: hcopy, + pays: paysCopy, + gsoSize: gsoSize, + isV6: isV6, + csumStart: csumStart, + }) + return nil +} + +func (w *fakeTunWriter) GSOSupported() bool { return w.gsoEnabled } + +// buildTCPv4 constructs a minimal IPv4+TCP packet with the given payload, +// seq, and flags. Assumes no IP options and a 20-byte TCP header. +func buildTCPv4(seq uint32, flags byte, payload []byte) []byte { + return buildTCPv4Ports(1000, 2000, seq, flags, payload) +} + +// buildTCPv4Ports is buildTCPv4 with caller-specified ports so tests can +// build distinct flows. +func buildTCPv4Ports(sport, dport uint16, seq uint32, flags byte, payload []byte) []byte { + const ipHdrLen = 20 + const tcpHdrLen = 20 + total := ipHdrLen + tcpHdrLen + len(payload) + pkt := make([]byte, total) + + pkt[0] = 0x45 + pkt[1] = 0x00 + binary.BigEndian.PutUint16(pkt[2:4], uint16(total)) + binary.BigEndian.PutUint16(pkt[4:6], 0) + binary.BigEndian.PutUint16(pkt[6:8], 0x4000) + pkt[8] = 64 + pkt[9] = ipProtoTCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + + binary.BigEndian.PutUint16(pkt[20:22], sport) + binary.BigEndian.PutUint16(pkt[22:24], dport) + binary.BigEndian.PutUint32(pkt[24:28], seq) + binary.BigEndian.PutUint32(pkt[28:32], 12345) + pkt[32] = 0x50 + pkt[33] = flags + binary.BigEndian.PutUint16(pkt[34:36], 0xffff) + + copy(pkt[40:], payload) + return pkt +} + +const ( + tcpAck = 0x10 + tcpPsh = 0x08 + tcpSyn = 0x02 + tcpFin = 0x01 + tcpAckPsh = tcpAck | tcpPsh +) + +func TestCoalescerPassthroughWhenGSOUnavailable(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: false} + c := newTCPCoalescer(w) + pkt := buildTCPv4(1000, tcpAck, []byte("hello")) + if err := c.Add(pkt); err != nil { + t.Fatal(err) + } + // No sync write — passthrough is deferred to Flush. + if len(w.writes) != 0 || len(w.gsoWrites) != 0 { + t.Fatalf("no Add-time writes: got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerNonTCPPassthrough(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pkt := make([]byte, 28) + pkt[0] = 0x45 + binary.BigEndian.PutUint16(pkt[2:4], 28) + pkt[9] = 1 + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + if err := c.Add(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("ICMP should pass through unchanged") + } +} + +func TestCoalescerSeedThenFlushAlone(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pkt := buildTCPv4(1000, tcpAck, make([]byte, 1000)) + if err := c.Add(pkt); err != nil { + t.Fatal(err) + } + if len(w.writes) != 0 || len(w.gsoWrites) != 0 { + t.Fatalf("unexpected output before flush") + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Single-segment flush now goes through WriteGSO with GSO_NONE + // (virtio NEEDS_CSUM lets the kernel fill in the L4 csum). + if len(w.gsoWrites) != 1 || len(w.writes) != 0 { + t.Fatalf("single-seg flush: writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } + g := w.gsoWrites[0] + if g.total() != 40+1000 { + t.Errorf("super total=%d want %d", g.total(), 40+1000) + } + if g.payLen() != 1000 { + t.Errorf("payLen=%d want 1000", g.payLen()) + } +} + +func TestCoalescerCoalescesAdjacentACKs(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4(2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes)) + } + g := w.gsoWrites[0] + if g.gsoSize != 1200 { + t.Errorf("gsoSize=%d want 1200", g.gsoSize) + } + if len(g.hdr) != 40 { + t.Errorf("hdrLen=%d want 40", len(g.hdr)) + } + if g.csumStart != 20 { + t.Errorf("csumStart=%d want 20", g.csumStart) + } + if len(g.pays) != 3 { + t.Errorf("pay count=%d want 3", len(g.pays)) + } + if g.total() != 40+3*1200 { + t.Errorf("superpacket len=%d want %d", g.total(), 40+3*1200) + } + if tot := binary.BigEndian.Uint16(g.hdr[2:4]); int(tot) != g.total() { + t.Errorf("ip total_length=%d want %d", tot, g.total()) + } +} + +func TestCoalescerRejectsSeqGap(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4(3000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Each packet flushes as its own single-segment WriteGSO now. + if len(w.gsoWrites) != 2 || len(w.writes) != 0 { + t.Fatalf("seq gap: want 2 gso writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsFlagMismatch(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // SYN|ACK is non-admissible. Must flush matching flow's slot (gso) + // and then plain-write the SYN packet itself. + syn := buildTCPv4(2200, tcpSyn|tcpAck, pay) + if err := c.Add(syn); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 1 { + t.Fatalf("flag mismatch: want 1 plain + 1 gso, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsFIN(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + fin := buildTCPv4(1000, tcpAck|tcpFin, []byte("x")) + if err := c.Add(fin); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // FIN isn't admissible — passthrough as plain, no slot, no gso. + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("FIN should be passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerShortLastSegmentClosesChain(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + full := make([]byte, 1200) + half := make([]byte, 500) + if err := c.Add(buildTCPv4(1000, tcpAck, full)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4(2200, tcpAck, half)); err != nil { + t.Fatal(err) + } + // Chain now closed; next packet seeds a new slot on the same flow + // after flushing the old one. + if err := c.Add(buildTCPv4(2700, tcpAck, full)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Expect two gso writes: the first two packets coalesced, then the + // third flushed alone (single-seg via GSO_NONE). + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes got %d", len(w.gsoWrites)) + } + if len(w.writes) != 0 { + t.Fatalf("want 0 plain writes got %d", len(w.writes)) + } + if w.gsoWrites[0].gsoSize != 1200 { + t.Errorf("gsoSize=%d want 1200", w.gsoWrites[0].gsoSize) + } + if got, want := w.gsoWrites[0].total(), 40+1200+500; got != want { + t.Errorf("super len=%d want %d", got, want) + } +} + +func TestCoalescerPSHFinalizesChain(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4(2200, tcpAckPsh, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // First two coalesce; the third seeds a fresh slot that flushes alone. + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes got %d", len(w.gsoWrites)) + } + if len(w.writes) != 0 { + t.Fatalf("want 0 plain writes got %d", len(w.writes)) + } +} + +func TestCoalescerRejectsDifferentFlow(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + p1 := buildTCPv4(1000, tcpAck, pay) + p2 := buildTCPv4(2200, tcpAck, pay) + binary.BigEndian.PutUint16(p2[20:22], 9999) + if err := c.Add(p1); err != nil { + t.Fatal(err) + } + if err := c.Add(p2); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Two independent flows, each flushes its own single-segment WriteGSO. + if len(w.gsoWrites) != 2 || len(w.writes) != 0 { + t.Fatalf("diff flow: want 2 gso writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsIPOptions(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 500) + pkt := buildTCPv4(1000, tcpAck, pay) + // Bump IHL to 6 to simulate 4 bytes of IP options. Don't actually add + // bytes — parser should bail before it matters. + pkt[0] = 0x46 + if err := c.Add(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Non-admissible parse → passthrough as plain. + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("IP options should passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerCapBySegments(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 512) + seq := uint32(1000) + for i := 0; i < tcpCoalesceMaxSegs+5; i++ { + if err := c.Add(buildTCPv4(seq, tcpAck, pay)); err != nil { + t.Fatal(err) + } + seq += uint32(len(pay)) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + for _, g := range w.gsoWrites { + segs := len(g.pays) + if segs > tcpCoalesceMaxSegs { + t.Fatalf("super exceeded seg cap: %d > %d", segs, tcpCoalesceMaxSegs) + } + } +} + +// TestCoalescerMultipleFlowsInSameBatch proves two interleaved bulk TCP +// flows coalesce independently in a single Flush. +func TestCoalescerMultipleFlowsInSameBatch(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + + // Flow A: sport 1000. Flow B: sport 3000. + if err := c.Add(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(1000, 2000, 1300, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(3000, 2000, 1700, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(1000, 2000, 2500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(3000, 2000, 2900, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (one per flow), got %d", len(w.gsoWrites)) + } + if len(w.writes) != 0 { + t.Fatalf("want no plain writes, got %d", len(w.writes)) + } + // Each superpacket should carry 3 segments. + for i, g := range w.gsoWrites { + if len(g.pays) != 3 { + t.Errorf("gso[%d]: segs=%d want 3", i, len(g.pays)) + } + if g.gsoSize != 1200 { + t.Errorf("gso[%d]: gsoSize=%d want 1200", i, g.gsoSize) + } + } + // Verify each superpacket carries the source port it was seeded with. + seenSports := map[uint16]bool{} + for _, g := range w.gsoWrites { + sp := binary.BigEndian.Uint16(g.hdr[20:22]) + seenSports[sp] = true + } + if !seenSports[1000] || !seenSports[3000] { + t.Errorf("expected superpackets for sports 1000 and 3000, got %v", seenSports) + } +} + +// TestCoalescerPreservesArrivalOrder confirms that with passthrough and +// coalesced events both queued, Flush emits them in Add order rather than +// writing passthrough packets synchronously. +func TestCoalescerPreservesArrivalOrder(t *testing.T) { + w := &orderedFakeWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + // Sequence: coalesceable TCP, ICMP (passthrough), coalesceable TCP on + // a different flow. Expected emit order: gso(X), plain(ICMP), gso(Y). + pay := make([]byte, 1200) + if err := c.Add(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil { + t.Fatal(err) + } + icmp := make([]byte, 28) + icmp[0] = 0x45 + binary.BigEndian.PutUint16(icmp[2:4], 28) + icmp[9] = 1 + copy(icmp[12:16], []byte{10, 0, 0, 1}) + copy(icmp[16:20], []byte{10, 0, 0, 3}) + if err := c.Add(icmp); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Nothing should have hit the writer synchronously. + if len(w.events) != 0 { + t.Fatalf("Add emitted events synchronously: %v", w.events) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if got, want := w.events, []string{"gso", "plain", "gso"}; !stringSliceEq(got, want) { + t.Fatalf("flush order=%v want %v", got, want) + } +} + +// orderedFakeWriter records only the sequence of call types so tests can +// assert arrival order without inspecting bytes. +type orderedFakeWriter struct { + gsoEnabled bool + events []string +} + +func (w *orderedFakeWriter) Write(p []byte) (int, error) { + w.events = append(w.events, "plain") + return len(p), nil +} + +func (w *orderedFakeWriter) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error { + w.events = append(w.events, "gso") + return nil +} + +func (w *orderedFakeWriter) GSOSupported() bool { return w.gsoEnabled } + +func stringSliceEq(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// TestCoalescerInterleavedFlowsPreserveOrdering checks that a non-admissible +// packet (SYN) mid-flow only flushes its own flow, not others. +func TestCoalescerInterleavedFlowsPreserveOrdering(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + + // Flow A two segments. + if err := c.Add(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(1000, 2000, 1300, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Flow B two segments. + if err := c.Add(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4Ports(3000, 2000, 1700, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Flow A SYN (non-admissible) — must flush only flow A's slot. + syn := buildTCPv4Ports(1000, 2000, 9999, tcpSyn|tcpAck, pay) + if err := c.Add(syn); err != nil { + t.Fatal(err) + } + // Flow B continues — should still be coalesced with its seed. + if err := c.Add(buildTCPv4Ports(3000, 2000, 2900, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + // Expected: + // - 1 gso for flow A (first 2 segments) + // - 1 plain for flow A SYN + // - 1 gso for flow B (3 segments) + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes, got %d", len(w.gsoWrites)) + } + if len(w.writes) != 1 { + t.Fatalf("want 1 plain write (SYN), got %d", len(w.writes)) + } + // Find the 3-segment gso (flow B) and the 2-segment gso (flow A). + var segCounts []int + for _, g := range w.gsoWrites { + segCounts = append(segCounts, len(g.pays)) + } + if !(segCounts[0] == 2 && segCounts[1] == 3) && !(segCounts[0] == 3 && segCounts[1] == 2) { + t.Errorf("unexpected segment counts: %v (want 2 and 3)", segCounts) + } +} diff --git a/udp/conn.go b/udp/conn.go index 30d89dec..c80cad8b 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -8,6 +8,12 @@ import ( const MTU = 9001 +// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is +// required to accept. Callers SHOULD NOT pass more than this per call; Linux +// backends preallocate sendmmsg scratch sized to this value, so exceeding it +// only costs a chunked retry. +const MaxWriteBatch = 128 + type EncReader func( addr netip.AddrPort, payload []byte, @@ -16,8 +22,29 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) error + // ListenOut invokes r for each received packet. On batch-capable + // backends (recvmmsg), flush is called after each batch is fully + // delivered — callers use it to flush per-batch accumulators such as + // TUN write coalescers. Single-packet backends call flush after each + // packet. flush must not be nil. + ListenOut(r EncReader, flush func()) error WriteTo(b []byte, addr netip.AddrPort) error + // WriteBatch sends a contiguous batch of packets, each with its own + // destination. bufs and addrs must have the same length. Linux uses + // sendmmsg(2) for a single syscall; other backends fall back to a + // WriteTo loop. Returns on the first error; callers may observe a + // partial send if some packets went out before the error. + WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error + // WriteSegmented sends bufs as a single UDP GSO sendmsg when the kernel + // supports it: all bufs go to the same addr, each must be exactly segSize + // bytes except the last which may be shorter. The kernel emits one + // datagram per buf on the wire. Backends / kernels without GSO support + // fall back to a per-packet WriteTo loop. Returns on the first error. + WriteSegmented(bufs [][]byte, addr netip.AddrPort, segSize int) error + // SupportsGSO reports whether WriteSegmented takes the single-syscall + // GSO path. Callers use this to decide at batch-assembly time whether + // the uniform-size / same-dst check is worth running. + SupportsGSO() bool ReloadConfig(c *config.C) SupportsMultipleReaders() bool Close() error @@ -31,7 +58,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) error { +func (NoopConn) ListenOut(_ EncReader, _ func()) error { return nil } func (NoopConn) SupportsMultipleReaders() bool { @@ -40,6 +67,15 @@ func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error { + return nil +} +func (NoopConn) WriteSegmented(_ [][]byte, _ netip.AddrPort, _ int) error { + return nil +} +func (NoopConn) SupportsGSO() bool { + return false +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 863c98f3..e4bdd659 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,26 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + +func (u *StdConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error { + for _, b := range bufs { + if err := u.WriteTo(b, addr); err != nil { + return err + } + } + return nil +} + +func (u *StdConn) SupportsGSO() bool { return false } + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -165,7 +185,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) for { @@ -180,6 +200,7 @@ func (u *StdConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + flush() } } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index ad26f794..307ed0d3 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -44,6 +44,26 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { return err } +func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil { + return err + } + } + return nil +} + +func (u *GenericConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error { + for _, b := range bufs { + if _, err := u.UDPConn.WriteToUDPAddrPort(b, addr); err != nil { + return err + } + } + return nil +} + +func (u *GenericConn) SupportsGSO() bool { return false } + func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -73,7 +93,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) error { +func (u *GenericConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -94,6 +114,7 @@ func (u *GenericConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + flush() } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 21a34147..40fd0463 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -24,6 +24,32 @@ type StdConn struct { isV4 bool l *logrus.Logger batch int + + // sendmmsg scratch. Each queue has its own StdConn, so no locking is + // needed. Sized to MaxWriteBatch at construction; WriteBatch chunks + // larger inputs. + writeMsgs []rawMessage + writeIovs []iovec + writeNames [][]byte + + // Preallocated closure + in/out slots for sendmmsg, so the hot path + // does not heap-allocate a fresh closure per call. + writeChunk int + writeSent int + writeErrno syscall.Errno + writeFunc func(fd uintptr) bool + + // UDP GSO (sendmsg with UDP_SEGMENT cmsg) support. gsoSupported is + // probed once at socket creation. When true, WriteSegmented takes a + // single-syscall GSO path; otherwise it falls back to a WriteTo loop. + gsoSupported bool + gsoMsg msghdr + gsoIovs []iovec + gsoName []byte // SizeofSockaddrInet6 + gsoCmsg []byte // CmsgSpace(2) + gsoSent int + gsoErrno syscall.Errno + gsoFunc func(fd uintptr) bool } func setReusePort(network, address string, c syscall.RawConn) error { @@ -70,9 +96,61 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in } out.isV4 = af == unix.AF_INET + out.prepareWriteMessages(MaxWriteBatch) + out.writeFunc = out.sendmmsgRawWrite + + out.prepareGSO() + return out, nil } +// maxGSOSegments caps the per-sendmsg GSO fan-out. Linux kernels have +// historically capped UDP_MAX_SEGMENTS at 64; newer kernels raise it to 128 +// but we stay conservative so the same code works everywhere. +const maxGSOSegments = 64 + +// maxGSOBytes bounds the total payload per sendmsg() when UDP_SEGMENT is +// set. The kernel stitches all iovecs into a single skb whose length the +// UDP length field can represent, and also enforces sk_gso_max_size (which +// on most devices is 65536). We use 65535 so ciphertext + headers always +// fits, avoiding EMSGSIZE on large TSO superpackets. +const maxGSOBytes = 65535 + +// prepareGSO probes UDP_SEGMENT support and, on success, sets up the +// reusable sendmsg scratch (iovecs, sockaddr, cmsg) plus the preallocated +// raw-write closure used to avoid heap allocations on the hot path. +func (u *StdConn) prepareGSO() { + var probeErr error + if err := u.rawConn.Control(func(fd uintptr) { + probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT, 0) + }); err != nil { + return + } + if probeErr != nil { + return + } + u.gsoSupported = true + u.gsoIovs = make([]iovec, maxGSOSegments) + u.gsoName = make([]byte, unix.SizeofSockaddrInet6) + u.gsoCmsg = make([]byte, unix.CmsgSpace(2)) + + // Wire up the static pieces of gsoMsg. Iovlen / Controllen / Namelen / + // cmsg contents get refreshed per call; Iov, Name, Control pointers are + // fixed because the scratch slices never move. + u.gsoMsg.Iov = &u.gsoIovs[0] + u.gsoMsg.Name = &u.gsoName[0] + u.gsoMsg.Control = &u.gsoCmsg[0] + + // Prepopulate the cmsg header. Len/Level/Type are constant for our use; + // only the 2-byte gso_size payload changes per call. + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(&u.gsoCmsg[0])) + cmsghdr.Level = unix.SOL_UDP + cmsghdr.Type = unix.UDP_SEGMENT + setCmsgLen(cmsghdr, unix.CmsgLen(2)) + + u.gsoFunc = u.sendmsgRawWriteGSO +} + func (u *StdConn) SupportsMultipleReaders() bool { return true } @@ -171,7 +249,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) error { +func (u *StdConn) listenOutSingle(r EncReader, flush func()) error { var err error var n int var from netip.AddrPort @@ -184,10 +262,11 @@ func (u *StdConn) listenOutSingle(r EncReader) error { } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) r(from, buffer[:n]) + flush() } } -func (u *StdConn) listenOutBatch(r EncReader) error { +func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { var ip netip.Addr var n int var operr error @@ -219,14 +298,17 @@ func (u *StdConn) listenOutBatch(r EncReader) error { } r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } + // End-of-batch: let callers (e.g. TUN write coalescer) flush any + // state they accumulated across this batch. + flush() } } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { if u.batch == 1 { - return u.listenOutSingle(r) + return u.listenOutSingle(r, flush) } else { - return u.listenOutBatch(r) + return u.listenOutBatch(r, flush) } } @@ -235,6 +317,237 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return err } +// WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on +// StdConn. Chunks larger than the scratch are processed in multiple syscalls. +// If sendmmsg returns a fatal error mid-chunk we fall back to single WriteTo +// calls for the remainder so the caller still gets best-effort delivery. +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + if len(bufs) != len(addrs) { + return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs)) + } + //u.l.WithField("bufs", len(bufs)).Info("WriteBatch") + i := 0 + for i < len(bufs) { + chunk := len(bufs) - i + if chunk > len(u.writeMsgs) { + chunk = len(u.writeMsgs) + } + + for k := 0; k < chunk; k++ { + b := bufs[i+k] + if len(b) == 0 { + // sendmmsg with an empty iovec is legal but pointless; fall + // through after filling the slot so Base is still valid. + u.writeIovs[k].Base = nil + setIovLen(&u.writeIovs[k], 0) + } else { + u.writeIovs[k].Base = &b[0] + setIovLen(&u.writeIovs[k], len(b)) + } + nlen, err := writeSockaddr(u.writeNames[k], addrs[i+k], u.isV4) + if err != nil { + return err + } + u.writeMsgs[k].Hdr.Namelen = uint32(nlen) + } + + sent, serr := u.sendmmsg(chunk) + if serr != nil { + if sent <= 0 { + // nothing went out; fall back to WriteTo for this chunk. + for k := 0; k < chunk; k++ { + if err := u.WriteTo(bufs[i+k], addrs[i+k]); err != nil { + return err + } + } + i += chunk + continue + } + // partial: treat as success for the sent packets and retry the + // remainder on the next outer-loop iteration. + } + if sent == 0 { + return fmt.Errorf("sendmmsg made no progress") + } + i += sent + } + return nil +} + +// sendmmsgRawWrite is the preallocated callback passed to rawConn.Write. It +// reads its input (u.writeChunk) and writes its outputs (u.writeSent, +// u.writeErrno) through StdConn fields so the closure itself does not +// capture per-call locals and therefore does not heap-allocate. +func (u *StdConn) sendmmsgRawWrite(fd uintptr) bool { + r1, _, errno := unix.Syscall6( + unix.SYS_SENDMMSG, + fd, + uintptr(unsafe.Pointer(&u.writeMsgs[0])), + uintptr(u.writeChunk), + 0, + 0, + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + return false + } + u.writeSent = int(r1) + u.writeErrno = errno + return true +} + +func (u *StdConn) SupportsGSO() bool { + return u.gsoSupported +} + +// WriteSegmented sends bufs to addr as a UDP GSO superpacket. The kernel +// emits one datagram per iovec on the wire; all iovecs except the last must +// be exactly segSize bytes. Non-GSO kernels hit the WriteTo fallback. +// Called with len(bufs) >= 1. len(bufs) > maxGSOSegments is chunked. +func (u *StdConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, segSize int) error { + if len(bufs) == 0 { + return nil + } + if !u.gsoSupported { + for _, b := range bufs { + if err := u.WriteTo(b, addr); err != nil { + return err + } + } + return nil + } + + nlen, err := writeSockaddr(u.gsoName, addr, u.isV4) + if err != nil { + return err + } + u.gsoMsg.Namelen = uint32(nlen) + setMsgControllen(&u.gsoMsg, unix.CmsgSpace(2)) + + // Cap the per-syscall fan-out by both segment count and total bytes. + // Kernel rejects sendmsg with EMSGSIZE when segCount*segSize would + // exceed sk_gso_max_size (typically 65536). For segSize > maxGSOBytes + // we can't use GSO at all and must fall back per-packet. + segsByBytes := maxGSOBytes / segSize + if segsByBytes == 0 { + for _, b := range bufs { + if werr := u.WriteTo(b, addr); werr != nil { + return werr + } + } + return nil + } + maxChunk := maxGSOSegments + if segsByBytes < maxChunk { + maxChunk = segsByBytes + } + + i := 0 + for i < len(bufs) { + chunk := len(bufs) - i + if chunk > maxChunk { + chunk = maxChunk + } + for k := 0; k < chunk; k++ { + b := bufs[i+k] + if len(b) == 0 { + u.gsoIovs[k].Base = nil + setIovLen(&u.gsoIovs[k], 0) + } else { + u.gsoIovs[k].Base = &b[0] + setIovLen(&u.gsoIovs[k], len(b)) + } + } + setMsgIovlen(&u.gsoMsg, chunk) + binary.NativeEndian.PutUint16(u.gsoCmsg[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize)) + + if serr := u.sendmsgGSO(); serr != nil { + // Fall back to a per-packet loop for the remainder of the + // batch. Dropping the GSO call entirely is safer than + // returning mid-superpacket and losing bytes. + for k := 0; k < chunk; k++ { + if werr := u.WriteTo(bufs[i+k], addr); werr != nil { + return werr + } + } + } + i += chunk + } + return nil +} + +// sendmsgRawWriteGSO is the preallocated rawConn.Write callback for the GSO +// path. Reads the prebuilt u.gsoMsg and writes u.gsoSent / u.gsoErrno. +func (u *StdConn) sendmsgRawWriteGSO(fd uintptr) bool { + r1, _, errno := unix.Syscall( + unix.SYS_SENDMSG, + fd, + uintptr(unsafe.Pointer(&u.gsoMsg)), + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + return false + } + u.gsoSent = int(r1) + u.gsoErrno = errno + return true +} + +func (u *StdConn) sendmsgGSO() error { + u.gsoSent = 0 + u.gsoErrno = 0 + if err := u.rawConn.Write(u.gsoFunc); err != nil { + return err + } + if u.gsoErrno != 0 { + return &net.OpError{Op: "sendmsg", Err: u.gsoErrno} + } + return nil +} + +func (u *StdConn) sendmmsg(n int) (int, error) { + u.writeChunk = n + u.writeSent = 0 + u.writeErrno = 0 + if err := u.rawConn.Write(u.writeFunc); err != nil { + return u.writeSent, err + } + if u.writeErrno != 0 { + return u.writeSent, &net.OpError{Op: "sendmmsg", Err: u.writeErrno} + } + return u.writeSent, nil +} + +// writeSockaddr encodes addr into buf (which must be at least +// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is +// true and addr is not a v4 (or v4-in-v6) address, returns an error. +func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) { + ap := addr.Addr().Unmap() + if isV4 { + if !ap.Is4() { + return 0, ErrInvalidIPv6RemoteForSocket + } + // struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) } + // sa_family is host endian. + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + ip4 := ap.As4() + copy(buf[4:8], ip4[:]) + for j := 8; j < 16; j++ { + buf[j] = 0 + } + return unix.SizeofSockaddrInet4, nil + } + // struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) } + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + binary.NativeEndian.PutUint32(buf[4:8], 0) + ip6 := addr.Addr().As16() + copy(buf[8:24], ip6[:]) + binary.NativeEndian.PutUint32(buf[24:28], 0) + return unix.SizeofSockaddrInet6, nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..1fd469f3 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -52,3 +52,35 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +// prepareWriteMessages allocates one Mmsghdr/iovec/sockaddr scratch per slot, +// wired up so each writeMsgs[i] already points at writeIovs[i] and +// writeNames[i]. Callers fill in the iovec Base/Len, the sockaddr bytes, and +// Namelen before each sendmmsg. +func (u *StdConn) prepareWriteMessages(n int) { + u.writeMsgs = make([]rawMessage, n) + u.writeIovs = make([]iovec, n) + u.writeNames = make([][]byte, n) + for i := range u.writeMsgs { + u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + u.writeMsgs[i].Hdr.Iov = &u.writeIovs[i] + u.writeMsgs[i].Hdr.Iovlen = 1 + u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0] + } +} + +func setIovLen(v *iovec, n int) { + v.Len = uint32(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint32(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint32(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint32(n) +} diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..6f5008c0 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -55,3 +55,35 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +// prepareWriteMessages allocates one Mmsghdr/iovec/sockaddr scratch per slot, +// wired up so each writeMsgs[i] already points at writeIovs[i] and +// writeNames[i]. Callers fill in the iovec Base/Len, the sockaddr bytes, and +// Namelen before each sendmmsg. +func (u *StdConn) prepareWriteMessages(n int) { + u.writeMsgs = make([]rawMessage, n) + u.writeIovs = make([]iovec, n) + u.writeNames = make([][]byte, n) + for i := range u.writeMsgs { + u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + u.writeMsgs[i].Hdr.Iov = &u.writeIovs[i] + u.writeMsgs[i].Hdr.Iovlen = 1 + u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0] + } +} + +func setIovLen(v *iovec, n int) { + v.Len = uint64(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint64(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint64(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint64(n) +} diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 607b978e..1ee85165 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) error { +func (u *RIOConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -162,6 +162,7 @@ func (u *RIOConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + flush() } } @@ -316,6 +317,26 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } +func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + +func (u *RIOConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error { + for _, b := range bufs { + if err := u.WriteTo(b, addr); err != nil { + return err + } + } + return nil +} + +func (u *RIOConn) SupportsGSO() bool { return false } + func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5db72555..e7ef5c05 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -107,13 +107,34 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader) error { +func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + +func (u *TesterConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error { + for _, b := range bufs { + if err := u.WriteTo(b, addr); err != nil { + return err + } + } + return nil +} + +func (u *TesterConn) SupportsGSO() bool { return false } + +func (u *TesterConn) ListenOut(r EncReader, flush func()) error { for { p, ok := <-u.RxPackets if !ok { return os.ErrClosed } r(p.From, p.Data) + flush() } }