From 5cc3ff594a9d782a86f8484f623f4805f72ca6e9 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 4 Nov 2025 15:04:24 -0500 Subject: [PATCH] use wg tun library; batching & locking improvements --- cmd/nebula-cert/test_freebsd.go | 4 + cmd/nebula-cert/test_openbsd.go | 4 + connection_manager_test.go | 9 +- firewall.go | 19 ++- go.mod | 6 +- go.sum | 12 +- inside.go | 254 ++++++++++++++++++++++++++++ interface.go | 75 +++++++-- main.go | 2 +- outside.go | 121 +++++++++++++- overlay/device.go | 18 +- overlay/tun.go | 1 + overlay/tun_android.go | 25 ++- overlay/tun_darwin.go | 28 +++- overlay/tun_disabled.go | 28 +++- overlay/tun_freebsd.go | 29 +++- overlay/tun_ios.go | 25 ++- overlay/tun_linux.go | 285 +++++++++++++++++++++++++++----- overlay/tun_netbsd.go | 26 ++- overlay/tun_openbsd.go | 26 ++- overlay/tun_tester.go | 38 ++++- overlay/tun_windows.go | 29 +++- overlay/user.go | 28 +++- stats.go | 1 + test/{tun.go => device/noop.go} | 20 ++- udp/conn.go | 28 +++- udp/udp_darwin.go | 39 +++++ udp/udp_generic.go | 35 ++++ udp/udp_linux.go | 178 +++++++++++++++++++- udp/udp_linux_32.go | 4 +- udp/udp_linux_64.go | 4 +- udp/udp_rio_windows.go | 44 +++++ udp/udp_tester.go | 29 ++++ 33 files changed, 1353 insertions(+), 121 deletions(-) create mode 100644 cmd/nebula-cert/test_freebsd.go create mode 100644 cmd/nebula-cert/test_openbsd.go rename test/{tun.go => device/noop.go} (55%) diff --git a/cmd/nebula-cert/test_freebsd.go b/cmd/nebula-cert/test_freebsd.go new file mode 100644 index 00000000..7276dfa1 --- /dev/null +++ b/cmd/nebula-cert/test_freebsd.go @@ -0,0 +1,4 @@ +package main + +const NoSuchFileError = "no such file or directory" +const NoSuchDirError = "no such file or directory" diff --git a/cmd/nebula-cert/test_openbsd.go b/cmd/nebula-cert/test_openbsd.go new file mode 100644 index 00000000..7276dfa1 --- /dev/null +++ b/cmd/nebula-cert/test_openbsd.go @@ -0,0 +1,4 @@ +package main + +const NoSuchFileError = "no such file or directory" +const NoSuchDirError = "no such file or directory" diff --git a/connection_manager_test.go b/connection_manager_test.go index 647dd72b..f8b15b6c 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -11,6 +11,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" + "github.com/slackhq/nebula/test/device" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &device.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &device.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &device.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &device.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, diff --git a/firewall.go b/firewall.go index 45dc0691..8877bd83 100644 --- a/firewall.go +++ b/firewall.go @@ -81,8 +81,14 @@ type FirewallConntrack struct { Conns map[firewall.Packet]*conn TimerWheel *TimerWheel[firewall.Packet] + + // purgeCounter tracks lookups to trigger periodic purge instead of every lookup + purgeCounter uint32 } +// purgeInterval defines how many lookups between purge attempts +const conntrackPurgeInterval = 1024 + // FirewallTable is the entry point for a rule, the evaluation order is: // Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR) type FirewallTable struct { @@ -492,14 +498,17 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, conntrack := f.Conntrack conntrack.Lock() - // Purge every time we test - ep, has := conntrack.TimerWheel.Purge() - if has { - f.evict(ep) + // Periodic purge instead of every lookup (major CPU savings) + conntrack.purgeCounter++ + if conntrack.purgeCounter >= conntrackPurgeInterval { + conntrack.purgeCounter = 0 + ep, has := conntrack.TimerWheel.Purge() + if has { + f.evict(ep) + } } c, ok := conntrack.Conns[fp] - if !ok { conntrack.Unlock() return false diff --git a/go.mod b/go.mod index 1c564d03..d151fb05 100644 --- a/go.mod +++ b/go.mod @@ -30,11 +30,11 @@ require ( golang.org/x/sys v0.40.0 golang.org/x/term v0.39.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b + golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 - gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe + gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c ) require ( @@ -50,6 +50,6 @@ require ( github.com/vishvananda/netns v0.0.5 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/mod v0.31.0 // indirect - golang.org/x/time v0.5.0 // indirect + golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.40.0 // indirect ) diff --git a/go.sum b/go.sum index c4613e01..1a038a9d 100644 --- a/go.sum +++ b/go.sum @@ -216,8 +216,8 @@ golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -231,8 +231,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= -golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= +golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= +golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -258,5 +258,5 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g= -gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= diff --git a/inside.go b/inside.go index 0d53f952..f7caebe2 100644 --- a/inside.go +++ b/inside.go @@ -2,6 +2,8 @@ package nebula import ( "net/netip" + "sync" + "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" @@ -11,6 +13,258 @@ import ( "github.com/slackhq/nebula/routing" ) +// preEncryptionPacket holds packet data before batch encryption +type preEncryptionPacket struct { + hostinfo *HostInfo + ci *ConnectionState + packet []byte + out []byte +} + +// Pool for preEncryptionBatch slices to reduce allocations +var preEncryptionBatchPool = sync.Pool{ + New: func() any { + // Pre-allocate with reasonable capacity + batch := make([]preEncryptionPacket, 0, 128) + return &batch + }, +} + +// consumeInsidePackets processes multiple packets in a batch for improved performance +// packets: slice of packet buffers to process +// sizes: slice of packet sizes +// count: number of packets to process +// outs: slice of output buffers (one per packet) with virtio headroom +// q: queue index +// localCache: firewall conntrack cache +// batchPackets: pre-allocated slice for accumulating encrypted packets +// batchAddrs: pre-allocated slice for accumulating destination addresses +func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) { + // Reusable per-packet state + fwPacket := &firewall.Packet{} + + // Reset batch accumulation slices (reuse capacity) + *batchPackets = (*batchPackets)[:0] + *batchAddrs = (*batchAddrs)[:0] + + // Collect packets for batched encryption + preEncryptionBatch := make([]preEncryptionPacket, 0, count) + + // Process each packet in the batch + for i := 0; i < count; i++ { + packet := packets[i][:sizes[i]] + out := outs[i] + + // Inline the consumeInsidePacket logic for better performance + err := newPacket(packet, false, fwPacket) + if err != nil { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + } + continue + } + + // Ignore local broadcast packets + if f.dropLocalBroadcast { + if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { + continue + } + } + + if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { + // Immediately forward packets from self to self. + if immediatelyForwardToSelf { + _, err := f.readers[q].Write(packet) + if err != nil { + f.l.WithError(err).Error("Failed to forward to tun") + } + } + continue + } + + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { + continue + } + + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + + if hostinfo == nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). + WithField("fwPacket", fwPacket). + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + } + continue + } + + if !ready { + continue + } + + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + if dropReason != nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). + WithField("fwPacket", fwPacket). + WithField("reason", dropReason). + Debugln("dropping outbound packet") + } + continue + } + + // Prepare packet for batch encryption + ci := hostinfo.ConnectionState + if ci.eKey == nil { + continue + } + + // Check if this needs relay - if so, send immediately and skip batching + useRelay := !hostinfo.remote.IsValid() + if useRelay { + // Handle relay sends individually (less common path) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q) + continue + } + + // Collect for batched encryption + preEncryptionBatch = append(preEncryptionBatch, preEncryptionPacket{ + hostinfo: hostinfo, + ci: ci, + packet: packet, + out: out, + }) + } + + // BATCH ENCRYPTION: Process all collected packets + if len(preEncryptionBatch) > 0 { + f.encryptBatch(preEncryptionBatch, nb, batchPackets, batchAddrs) + } + + // Send all accumulated packets in one batch + if len(*batchPackets) > 0 { + batchSize := len(*batchPackets) + n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs) + if err != nil { + f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch") + } + } +} + +// encryptBatch processes multiple packets, grouping by ConnectionState to reduce lock acquisitions +func (f *Interface) encryptBatch(batch []preEncryptionPacket, nb []byte, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) { + lockStart := time.Now() + lockAcquisitions := int64(0) + + if noiseutil.EncryptLockNeeded { + // Group packets by ConnectionState to minimize lock acquisitions + // Process packets in order but batch lock acquisitions per CI + var currentCI *ConnectionState + var lockHeld bool + + for i := range batch { + ci := batch[i].ci + hostinfo := batch[i].hostinfo + + // Validate packet data to prevent nil pointer dereference + if ci == nil || hostinfo == nil { + continue + } + + // Switch locks if we're moving to a different ConnectionState + if ci != currentCI { + if lockHeld { + currentCI.writeLock.Unlock() + } + ci.writeLock.Lock() + lockAcquisitions++ + currentCI = ci + lockHeld = true + } + + c := ci.messageCounter.Add(1) + out := header.Encode(batch[i].out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + // Query lighthouse if needed + if hostinfo.lastRebindCount != f.rebindCount { + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + var err error + out, err = ci.eKey.EncryptDanger(out, out, batch[i].packet, c, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("counter", c). + Error("Failed to encrypt outgoing packet") + continue + } + + // Add to output batches + *batchPackets = append(*batchPackets, out) + *batchAddrs = append(*batchAddrs, hostinfo.remote) + } + + // Release final lock + if lockHeld { + currentCI.writeLock.Unlock() + } + } else { + // No locks needed - process directly + for i := range batch { + ci := batch[i].ci + hostinfo := batch[i].hostinfo + + // Validate packet data to prevent nil pointer dereference + if ci == nil || hostinfo == nil { + continue + } + + c := ci.messageCounter.Add(1) + out := header.Encode(batch[i].out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + // Query lighthouse if needed + if hostinfo.lastRebindCount != f.rebindCount { + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + var err error + out, err = ci.eKey.EncryptDanger(out, out, batch[i].packet, c, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("counter", c). + Error("Failed to encrypt outgoing packet") + continue + } + + // Add to output batches + *batchPackets = append(*batchPackets, out) + *batchAddrs = append(*batchAddrs, hostinfo.remote) + } + } + + // Record metrics + encryptionTime := time.Since(lockStart) + if noiseutil.EncryptLockNeeded { + f.batchMetrics.lockAcquisitions.Inc(lockAcquisitions) + } + f.batchMetrics.encryptionTime.Update(encryptionTime.Nanoseconds()) + f.batchMetrics.batchSize.Update(int64(len(batch))) +} + func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { diff --git a/interface.go b/interface.go index 61b1f228..fb67fe84 100644 --- a/interface.go +++ b/interface.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/netip" "os" "runtime" @@ -22,6 +21,7 @@ import ( ) const mtu = 9001 +const virtioNetHdrLen = overlay.VirtioNetHdrLen type InterfaceConfig struct { HostMap *HostMap @@ -50,6 +50,13 @@ type InterfaceConfig struct { l *logrus.Logger } +type batchMetrics struct { + udpReadSize metrics.Histogram + encryptionTime metrics.Histogram // Time spent in encryption (including lock waits) + batchSize metrics.Histogram // Dynamic batch sizes being used + lockAcquisitions metrics.Counter // Number of lock acquisitions (should be minimal) +} + type Interface struct { hostMap *HostMap outside udp.Conn @@ -87,11 +94,12 @@ type Interface struct { conntrackCacheTimeout time.Duration writers []udp.Conn - readers []io.ReadWriteCloser + readers []overlay.BatchReadWriter metricHandshakes metrics.Histogram messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics + batchMetrics *batchMetrics l *logrus.Logger } @@ -178,7 +186,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), + readers: make([]overlay.BatchReadWriter, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, @@ -194,6 +202,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), }, + batchMetrics: &batchMetrics{ + udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)), + encryptionTime: metrics.GetOrRegisterHistogram("batch.encryption_time_ns", nil, metrics.NewUniformSample(1024)), + batchSize: metrics.GetOrRegisterHistogram("batch.size", nil, metrics.NewUniformSample(1024)), + lockAcquisitions: metrics.GetOrRegisterCounter("batch.lock_acquisitions", nil), + }, l: c.l, } @@ -233,7 +247,7 @@ func (f *Interface) activate() { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues - var reader io.ReadWriteCloser = f.inside + var reader overlay.BatchReadWriter = f.inside for i := 0; i < f.routines; i++ { if i > 0 { reader, err = f.inside.NewMultiQueueReader() @@ -274,39 +288,68 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) + + // Pre-allocate output buffers for batch processing + batchSize := li.BatchSize() + outs := make([][]byte, batchSize) + for idx := range outs { + // Allocate full buffer with virtio header space + outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU) + } + h := &header.H{} fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + nb := make([]byte, 12) + + li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) { + f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l)) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) } -func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { +func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) { runtime.LockOSThread() - packet := make([]byte, mtu) - out := make([]byte, mtu) - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + batchSize := reader.BatchSize() + + // Allocate buffers for batch reading + bufs := make([][]byte, batchSize) + for idx := range bufs { + bufs[idx] = make([]byte, mtu) + } + sizes := make([]int, batchSize) + + // Allocate output buffers for batch processing (one per packet) + // Each has virtio header headroom to avoid copies on write + outs := make([][]byte, batchSize) + for idx := range outs { + outBuf := make([]byte, virtioNetHdrLen+mtu) + outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom + } + + // Pre-allocate batch accumulation buffers for sending + batchPackets := make([][]byte, 0, batchSize) + batchAddrs := make([]netip.AddrPort, 0, batchSize) + + // Pre-allocate nonce buffer (reused for all encryptions) + nb := make([]byte, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { - n, err := reader.Read(packet) + n, err := reader.BatchRead(bufs, sizes) if err != nil { if errors.Is(err, os.ErrClosed) && f.closed.Load() { return } - f.l.WithError(err).Error("Error while reading outbound packet") + f.l.WithError(err).Error("Error while batch reading outbound packets") // This only seems to happen when something fatal happens to the fd, so exit. os.Exit(2) } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + // Process all packets in the batch at once + f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) } } diff --git a/main.go b/main.go index 17aaa548..d2f89d3f 100644 --- a/main.go +++ b/main.go @@ -171,7 +171,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for i := 0; i < routines; i++ { l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) - udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) + udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } diff --git a/outside.go b/outside.go index 172c3e83..2088159b 100644 --- a/outside.go +++ b/outside.go @@ -102,7 +102,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, relay: relay, IsRelayed: true, } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(via, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -145,7 +145,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } //TODO: assert via is not relayed - lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f) // Fallthrough to the bottom to record incoming traffic @@ -167,7 +167,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) + f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out) } // Fallthrough to the bottom to record incoming traffic @@ -210,7 +210,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, return } - f.relayManager.HandleControlMsg(hostinfo, d, f) + f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f) default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) @@ -481,9 +481,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - err = newPacket(out, true, fwPacket) + packetData := out[virtioNetHdrLen:] + + err = newPacket(packetData, true, fwPacket) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). + hostinfo.logger(f.l).WithError(err).WithField("packet", packetData). Warnf("Error while validating inbound packet") return false } @@ -498,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) + f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). @@ -562,3 +564,108 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { // We also delete it from pending hostmap to allow for fast reconnect. f.handshakeManager.DeleteHostInfo(hostinfo) } + +// readOutsidePacketsBatch processes multiple packets received from UDP in a batch +// and writes all successfully decrypted packets to TUN in a single operation +func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) { + // Pre-allocate slice for accumulating successful decryptions + tunPackets := make([][]byte, 0, count) + + for i := 0; i < count; i++ { + payload := payloads[i] + addr := addrs[i] + out := outs[i] + + // Parse header + err := h.Parse(payload) + if err != nil { + if len(payload) > 1 { + f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err) + } + continue + } + + if addr.IsValid() { + if f.myVpnNetworksTable.Contains(addr.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") + } + continue + } + } + + var hostinfo *HostInfo + if h.Type == header.Message && h.Subtype == header.MessageRelay { + hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) + } else { + hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) + } + + var ci *ConnectionState + if hostinfo != nil { + ci = hostinfo.ConnectionState + } + + switch h.Type { + case header.Message: + if !f.handleEncrypted(ci, ViaSender{UdpAddr: addr}, h) { + continue + } + + switch h.Subtype { + case header.MessageNone: + // Decrypt packet + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") + continue + } + + packetData := out[virtioNetHdrLen:] + + err = newPacket(packetData, true, fwPacket) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet") + continue + } + + if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet") + continue + } + + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + if dropReason != nil { + f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet") + } + continue + } + + f.connectionManager.In(hostinfo) + // Add to batch for TUN write + tunPackets = append(tunPackets, out) + + case header.MessageRelay: + // Skip relay packets in batch mode for now (less common path) + f.readOutsidePackets(ViaSender{UdpAddr: addr}, out[:virtioNetHdrLen], payload, h, fwPacket, lhf, nb, q, localCache) + + default: + hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype) + } + + default: + // Handle non-Message types using single-packet path + f.readOutsidePackets(ViaSender{UdpAddr: addr}, out[:virtioNetHdrLen], payload, h, fwPacket, lhf, nb, q, localCache) + } + } + + if len(tunPackets) > 0 { + n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen) + if err != nil { + f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun") + } + + } +} diff --git a/overlay/device.go b/overlay/device.go index b6077aba..0b7a3f2b 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -7,12 +7,26 @@ import ( "github.com/slackhq/nebula/routing" ) -type Device interface { +// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations +type BatchReadWriter interface { io.ReadWriteCloser + + // BatchRead reads multiple packets at once + BatchRead(bufs [][]byte, sizes []int) (int, error) + + // WriteBatch writes multiple packets at once + WriteBatch(bufs [][]byte, offset int) (int, error) + + // BatchSize returns the optimal batch size for this device + BatchSize() int +} + +type Device interface { + BatchReadWriter Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool - NewMultiQueueReader() (io.ReadWriteCloser, error) + NewMultiQueueReader() (BatchReadWriter, error) } diff --git a/overlay/tun.go b/overlay/tun.go index e0bf69f6..da8f6c8d 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -11,6 +11,7 @@ import ( ) const DefaultMTU = 1300 +const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure type NameError struct { Name string diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..2f02c0bb 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -99,6 +99,29 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for android") } + +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..1d04bf5d 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -553,6 +553,32 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +// BatchRead reads a single packet (batch size 1 for non-Linux platforms) +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for non-Linux platforms) +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for non-Linux platforms (no batching) +func (t *tun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..ab524238 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -109,10 +109,36 @@ func (t *disabledTun) SupportsMultiqueue() bool { return true } -func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) { return t, nil } +// BatchRead reads a single packet (batch size 1 for disabled tun) +func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for disabled tun) +func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for disabled tun (no batching) +func (t *disabledTun) BatchSize() int { + return 1 +} + func (t *disabledTun) Close() error { if t.read != nil { close(t.read) diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 2f65b3a4..2e937881 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -7,7 +7,6 @@ import ( "bytes" "errors" "fmt" - "io" "io/fs" "net/netip" "sync/atomic" @@ -454,10 +453,36 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } +// BatchRead reads a single packet (batch size 1 for FreeBSD) +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for FreeBSD) +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for FreeBSD (no batching) +func (t *tun) BatchSize() int { + return 1 +} + func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..6f35bd27 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -155,6 +155,29 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") } + +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..0c8ea5d5 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -9,7 +9,6 @@ import ( "net" "net/netip" "os" - "strings" "sync" "sync/atomic" "time" @@ -22,10 +21,12 @@ import ( "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" + wgtun "golang.zx2c4.com/wireguard/tun" ) type tun struct { io.ReadWriteCloser + wgDevice wgtun.Device fd int Device string vpnNetworks []netip.Prefix @@ -71,63 +72,154 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") +// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser +// This allows multiqueue readers to use the same wireguard Device batching as the main device +type wgDeviceWrapper struct { + dev wgtun.Device + buf []byte // Reusable buffer for single packet reads +} +func (w *wgDeviceWrapper) Read(b []byte) (int, error) { + // Use wireguard Device's batch API for single packet + bufs := [][]byte{b} + sizes := make([]int, 1) + n, err := w.dev.Read(bufs, sizes, 0) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + return sizes[0], nil +} + +func (w *wgDeviceWrapper) Write(b []byte) (int, error) { + // Buffer b should have virtio header space (10 bytes) at the beginning + // The decrypted packet data starts at offset 10 + // Pass the full buffer to WireGuard with offset=virtioNetHdrLen + bufs := [][]byte{b} + n, err := w.dev.Write(bufs, VirtioNetHdrLen) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.ErrShortWrite + } + return len(b), nil +} + +func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) { + // Pass all buffers to WireGuard's batch write + return w.dev.Write(bufs, offset) +} + +func (w *wgDeviceWrapper) Close() error { + return w.dev.Close() +} + +// BatchRead implements batching for multiqueue readers +func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) { + // The zero here is offset. + return w.dev.Read(bufs, sizes, 0) +} + +// BatchSize returns the optimal batch size +func (w *wgDeviceWrapper) BatchSize() int { + return w.dev.BatchSize() +} + +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { + wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd) + if err != nil { + return nil, fmt.Errorf("failed to create TUN from FD: %w", err) + } + + file := wgDev.File() t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { + _ = wgDev.Close() return nil, err } - t.Device = "tun0" + t.wgDevice = wgDev + t.Device = name return t, nil } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { + // Check if /dev/net/tun exists, create if needed (for docker containers) + if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) { + if err := os.MkdirAll("/dev/net", 0755); err != nil { + return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil { + return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) + } + } + + devName := c.GetString("tun.dev", "") + mtu := c.GetInt("tun.mtu", DefaultMTU) + + // Create TUN device manually to support multiqueue fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - // If /dev/net/tun doesn't exist, try to create it (will happen in docker) - if os.IsNotExist(err) { - err = os.MkdirAll("/dev/net", 0755) - if err != nil { - return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) - } - err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) - if err != nil { - return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) - } - - fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) - } - } else { - return nil, err - } - } - - var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) - if multiqueue { - req.Flags |= unix.IFF_MULTI_QUEUE - } - nameStr := c.GetString("tun.dev", "") - copy(req.Name[:], nameStr) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { - return nil, &NameError{ - Name: nameStr, - Underlying: err, - } - } - name := strings.Trim(string(req.Name[:]), "\x00") - - file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } + var req ifReq + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + if multiqueue { + req.Flags |= unix.IFF_MULTI_QUEUE + } + copy(req.Name[:], devName) + if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) + return nil, err + } + + // Set nonblocking + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + // Enable TCP and UDP offload (TSO/GRO) for performance + // This allows the kernel to handle segmentation/coalescing + const ( + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 + ) + offloads := tunTCPOffloads | tunUDPOffloads + if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil { + // Log warning but don't fail - offload is optional + l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced") + } + + file := os.NewFile(uintptr(fd), "/dev/net/tun") + + // Create wireguard device from file descriptor + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create TUN from file: %w", err) + } + + name, err := wgDev.Name() + if err != nil { + _ = wgDev.Close() + return nil, fmt.Errorf("failed to get TUN device name: %w", err) + } + + // file is now owned by wgDev, get a new reference + file = wgDev.File() + t, err := newTunGeneric(c, l, file, vpnNetworks) + if err != nil { + _ = wgDev.Close() + return nil, err + } + + t.wgDevice = wgDev t.Device = name return t, nil @@ -238,22 +330,44 @@ func (t *tun) SupportsMultiqueue() bool { return true } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err } var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + // MUST match the flags used in newTun - includes IFF_VNET_HDR + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) return nil, err } + // Set nonblocking mode - CRITICAL for proper netpoller integration + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + // Get MTU from main device + mtu := t.MaxMTU + if mtu == 0 { + mtu = DefaultMTU + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") - return file, nil + // Create wireguard Device from the file descriptor (just like the main device) + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err) + } + + // Return a wrapper that uses the wireguard Device for all I/O + return &wgDeviceWrapper{dev: wgDev}, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -261,7 +375,68 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } +func (t *tun) Read(b []byte) (int, error) { + if t.wgDevice != nil { + // Use wireguard device which handles virtio headers internally + bufs := [][]byte{b} + sizes := make([]int, 1) + n, err := t.wgDevice.Read(bufs, sizes, 0) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + return sizes[0], nil + } + + // Fallback: direct read from file (shouldn't happen in normal operation) + return t.ReadWriteCloser.Read(b) +} + +// BatchRead reads multiple packets at once for improved performance +// bufs: slice of buffers to read into +// sizes: slice that will be filled with packet sizes +// Returns number of packets read +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + if t.wgDevice != nil { + return t.wgDevice.Read(bufs, sizes, 0) + } + + // Fallback: single packet read + n, err := t.ReadWriteCloser.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// BatchSize returns the optimal number of packets to read/write in a batch +func (t *tun) BatchSize() int { + if t.wgDevice != nil { + return t.wgDevice.BatchSize() + } + return 1 +} + func (t *tun) Write(b []byte) (int, error) { + if t.wgDevice != nil { + // Buffer b should have virtio header space (10 bytes) at the beginning + // The decrypted packet data starts at offset 10 + // Pass the full buffer to WireGuard with offset=virtioNetHdrLen + bufs := [][]byte{b} + n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.ErrShortWrite + } + return len(b), nil + } + + // Fallback: direct write (shouldn't happen in normal operation) var nn int maximum := len(b) @@ -284,6 +459,22 @@ func (t *tun) Write(b []byte) (int, error) { } } +// WriteBatch writes multiple packets to the TUN device in a single syscall +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + if t.wgDevice != nil { + return t.wgDevice.Write(bufs, offset) + } + + // Fallback: write individually (shouldn't happen in normal operation) + for i, buf := range bufs { + _, err := t.Write(buf) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) @@ -711,6 +902,10 @@ func (t *tun) Close() error { close(t.routeChan) } + if t.wgDevice != nil { + _ = t.wgDevice.Close() + } + if t.ReadWriteCloser != nil { _ = t.ReadWriteCloser.Close() } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..fd510697 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "net/netip" "os" "regexp" @@ -394,10 +393,33 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") } +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} + func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..8f772a6a 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "net/netip" "os" "regexp" @@ -314,10 +313,33 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") } +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} + func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..4ed50917 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -109,8 +109,12 @@ func (t *TestTun) Write(b []byte) (n int, err error) { return 0, io.ErrClosedPipe } - packet := make([]byte, len(b), len(b)) - copy(packet, b) + // Skip virtio header (consistent with production Linux tun) + // The buffer b has VirtioNetHdrLen bytes of header followed by the actual packet + data := b[VirtioNetHdrLen:] + + packet := make([]byte, len(data)) + copy(packet, data) t.TxPackets <- packet return len(b), nil } @@ -136,6 +140,34 @@ func (t *TestTun) SupportsMultiqueue() bool { return false } -func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } + +func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) { + if t.closed.Load() { + return 0, io.ErrClosedPipe + } + + for _, buf := range bufs { + // Strip the header at offset and send directly to channel + data := buf[offset:] + packet := make([]byte, len(data)) + copy(packet, data) + t.TxPackets <- packet + } + return len(bufs), nil +} + +func (t *TestTun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 223eabee..321afb35 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -6,7 +6,6 @@ package overlay import ( "crypto" "fmt" - "io" "net/netip" "os" "path/filepath" @@ -241,10 +240,36 @@ func (t *winTun) SupportsMultiqueue() bool { return false } -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") } +// BatchRead reads a single packet (batch size 1 for Windows) +func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for Windows) +func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for Windows (no batching) +func (t *winTun) BatchSize() int { + return 1 +} + func (t *winTun) Close() error { // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, // so to be certain, just remove everything before destroying. diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..2a5c86c2 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -50,10 +50,36 @@ func (d *UserDevice) SupportsMultiqueue() bool { return true } -func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) { return d, nil } +// BatchRead reads a single packet (batch size 1 for UserDevice) +func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := d.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for UserDevice) +func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := d.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for UserDevice (no batching) +func (d *UserDevice) BatchSize() int { + return 1 +} + func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { return d.inboundReader, d.outboundWriter } diff --git a/stats.go b/stats.go index c88c45cc..b86919cc 100644 --- a/stats.go +++ b/stats.go @@ -6,6 +6,7 @@ import ( "log" "net" "net/http" + _ "net/http/pprof" "runtime" "strconv" "time" diff --git a/test/tun.go b/test/device/noop.go similarity index 55% rename from test/tun.go rename to test/device/noop.go index fb32782f..c145c755 100644 --- a/test/tun.go +++ b/test/device/noop.go @@ -1,10 +1,11 @@ -package test +package device import ( "errors" "io" "net/netip" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/routing" ) @@ -38,10 +39,25 @@ func (NoopTun) SupportsMultiqueue() bool { return false } -func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (NoopTun) NewMultiQueueReader() (overlay.BatchReadWriter, error) { return nil, errors.New("unsupported") } func (NoopTun) Close() error { return nil } + +// BatchRead implements BatchReadWriter interface +func (NoopTun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + return 0, io.EOF +} + +// WriteBatch implements BatchReadWriter interface +func (NoopTun) WriteBatch(bufs [][]byte, offset int) (int, error) { + return len(bufs), nil +} + +// BatchSize implements BatchReadWriter interface +func (NoopTun) BatchSize() int { + return 1 +} diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..2c0eacbc 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -13,13 +13,21 @@ type EncReader func( payload []byte, ) +type EncBatchReader func( + addrs []netip.AddrPort, + payloads [][]byte, + count int, +) + type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) + ListenOutBatch(r EncBatchReader) WriteTo(b []byte, addr netip.AddrPort) error + WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) ReloadConfig(c *config.C) SupportsMultipleReaders() bool + BatchSize() int Close() error } @@ -31,17 +39,25 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return -} + +func (NoopConn) ListenOut(_ EncReader) {} + func (NoopConn) SupportsMultipleReaders() bool { return false } + +func (NoopConn) ListenOutBatch(_ EncBatchReader) {} + func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } -func (NoopConn) ReloadConfig(_ *config.C) { - return +func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) { + return 0, nil +} +func (NoopConn) ReloadConfig(_ *config.C) {} + +func (NoopConn) BatchSize() int { + return 1 } func (NoopConn) Close() error { return nil diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..50012b53 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +// WriteMulti sends multiple packets - fallback implementation without sendmmsg +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -188,6 +199,34 @@ func (u *StdConn) SupportsMultipleReaders() bool { return false } +// ListenOutBatch - fallback to single-packet reads for Darwin +func (u *StdConn) ListenOutBatch(r EncBatchReader) { + buffer := make([]byte, MTU) + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + // Just read one packet at a time and call batch callback with count=1 + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + u.l.WithError(err).Error("unexpected udp socket receive error") + } + + addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()) + payloads[0] = buffer[:n] + r(addrs, payloads, 1) + } +} + +func (u *StdConn) BatchSize() int { + return 1 +} + func (u *StdConn) Rebind() error { var err error if u.isV4 { diff --git a/udp/udp_generic.go b/udp/udp_generic.go index e9dad6c5..f8187718 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -101,3 +101,38 @@ func (u *GenericConn) ListenOut(r EncReader) { func (u *GenericConn) SupportsMultipleReaders() bool { return false } + +// ListenOutBatch - fallback to single-packet reads for generic platforms +func (u *GenericConn) ListenOutBatch(r EncBatchReader) { + buffer := make([]byte, MTU) + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + // Just read one packet at a time and call batch callback with count=1 + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()) + payloads[0] = buffer[:n] + r(addrs, payloads, 1) + } +} + +// WriteMulti sends multiple packets - fallback implementation +func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + +func (u *GenericConn) BatchSize() int { + return 1 +} diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..91784b1f 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -22,6 +22,11 @@ type StdConn struct { isV4 bool l *logrus.Logger batch int + + // Pre-allocated buffers for batch writes (sized for IPv6, works for both) + writeMsgs []rawMessage + writeIovecs []iovec + writeNames [][]byte } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch} + + // Pre-allocate write message structures for batching (sized for IPv6, works for both) + c.writeMsgs = make([]rawMessage, batch) + c.writeIovecs = make([]iovec, batch) + c.writeNames = make([][]byte, batch) + + for i := range c.writeMsgs { + // Allocate for IPv6 size (larger than IPv4, works for both) + c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + + // Point to the iovec in the slice + c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i] + c.writeMsgs[i].Hdr.Iovlen = 1 + + c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0] + // Namelen will be set appropriately in writeMulti4/writeMulti6 + } + + return c, err } func (u *StdConn) SupportsMultipleReaders() bool { @@ -122,7 +146,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOutBatch(r EncBatchReader) { var ip netip.Addr msgs, buffers, names := u.PrepareRawMessages(u.batch) @@ -131,6 +155,12 @@ func (u *StdConn) ListenOut(r EncReader) { read = u.ReadSingle } + udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)) + + // Pre-allocate slices for batch callback + addrs := make([]netip.AddrPort, u.batch) + payloads := make([][]byte, u.batch) + for { n, err := read(msgs) if err != nil { @@ -138,15 +168,21 @@ func (u *StdConn) ListenOut(r EncReader) { return } + udpBatchHist.Update(int64(n)) + + // Prepare batch data for i := 0; i < n; i++ { - // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + payloads[i] = buffers[i][:msgs[i].Len] } + + // Call batch callback with all packets + r(addrs, payloads, n) } } @@ -198,6 +234,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + if len(packets) != len(addrs) { + return 0, fmt.Errorf("packets and addrs length mismatch") + } + if len(packets) == 0 { + return 0, nil + } + if u.isV4 { + return u.writeMulti4(packets, addrs) + } + return u.writeMulti6(packets, addrs) +} + func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 @@ -252,6 +301,123 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } } +func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) { + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) + } + + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] + + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i + if !addrs[pktIdx].Addr().Is4() { + return sent + i, ErrInvalidIPv6RemoteForSocket + } + + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET + rsa.Addr = addrs[pktIdx].Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv4 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 + } + + // Send this batch + nsent, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(batchSize), + 0, + 0, + 0, + ) + + if err != 0 { + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} + } + + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } + } + + return sent, nil +} + +func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) { + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) + } + + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] + + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i + + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET6 + rsa.Addr = addrs[pktIdx].Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv6 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 + } + + // Send this batch + nsent, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(batchSize), + 0, + 0, + 0, + ) + + if err != 0 { + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} + } + + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } + } + + return sent, nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { @@ -309,6 +475,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { return nil } +func (u *StdConn) BatchSize() int { + return u.batch +} + func (u *StdConn) Close() error { return syscall.Close(u.sysFd) } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..707a2b1f 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -12,7 +12,7 @@ import ( type iovec struct { Base *byte - Len uint32 + Len uint } type msghdr struct { @@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, + {Base: &buffers[i][0], Len: uint(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..89c6695d 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -12,7 +12,7 @@ import ( type iovec struct { Base *byte - Len uint64 + Len uint } type msghdr struct { @@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, + {Base: &buffers[i][0], Len: uint(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 3d60f34c..749ea354 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -338,6 +338,50 @@ func (u *RIOConn) Rebind() error { func (u *RIOConn) ReloadConfig(*config.C) {} +// BatchSize returns 1 since RIO reads packets one at a time +func (u *RIOConn) BatchSize() int { + return 1 +} + +// ListenOutBatch - fallback to single-packet reads for RIO +func (u *RIOConn) ListenOutBatch(r EncBatchReader) { + buffer := make([]byte, MTU) + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + var lastRecvErr time.Time + + for { + n, rua, err := u.receive(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { + lastRecvErr = time.Now() + u.l.WithError(err).Warn("unexpected udp socket receive error") + } + continue + } + + addrs[0] = netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)) + payloads[0] = buffer[:n] + r(addrs, payloads, 1) + } +} + +// WriteMulti sends multiple packets - fallback implementation +func (u *RIOConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *RIOConn) Close() error { if !u.isOpen.CompareAndSwap(true, false) { return nil diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..5130db99 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -116,6 +116,31 @@ func (u *TesterConn) ListenOut(r EncReader) { } } +func (u *TesterConn) ListenOutBatch(r EncBatchReader) { + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + p, ok := <-u.RxPackets + if !ok { + return + } + addrs[0] = p.From + payloads[0] = p.Data + r(addrs, payloads, 1) + } +} + +func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *TesterConn) ReloadConfig(*config.C) {} func NewUDPStatsEmitter(_ []Conn) func() { @@ -131,6 +156,10 @@ func (u *TesterConn) SupportsMultipleReaders() bool { return false } +func (u *TesterConn) BatchSize() int { + return 1 +} + func (u *TesterConn) Rebind() error { return nil }