use wg tun library; batching & locking improvements

This commit is contained in:
Jay Wren
2025-11-04 15:04:24 -05:00
parent 42bee7cf17
commit 5cc3ff594a
33 changed files with 1353 additions and 121 deletions

254
inside.go
View File

@@ -2,6 +2,8 @@ package nebula
import (
"net/netip"
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
@@ -11,6 +13,258 @@ import (
"github.com/slackhq/nebula/routing"
)
// preEncryptionPacket holds packet data before batch encryption
type preEncryptionPacket struct {
hostinfo *HostInfo
ci *ConnectionState
packet []byte
out []byte
}
// Pool for preEncryptionBatch slices to reduce allocations
var preEncryptionBatchPool = sync.Pool{
New: func() any {
// Pre-allocate with reasonable capacity
batch := make([]preEncryptionPacket, 0, 128)
return &batch
},
}
// consumeInsidePackets processes multiple packets in a batch for improved performance
// packets: slice of packet buffers to process
// sizes: slice of packet sizes
// count: number of packets to process
// outs: slice of output buffers (one per packet) with virtio headroom
// q: queue index
// localCache: firewall conntrack cache
// batchPackets: pre-allocated slice for accumulating encrypted packets
// batchAddrs: pre-allocated slice for accumulating destination addresses
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
// Reusable per-packet state
fwPacket := &firewall.Packet{}
// Reset batch accumulation slices (reuse capacity)
*batchPackets = (*batchPackets)[:0]
*batchAddrs = (*batchAddrs)[:0]
// Collect packets for batched encryption
preEncryptionBatch := make([]preEncryptionPacket, 0, count)
// Process each packet in the batch
for i := 0; i < count; i++ {
packet := packets[i][:sizes[i]]
out := outs[i]
// Inline the consumeInsidePacket logic for better performance
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
}
continue
}
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
continue
}
}
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
// Immediately forward packets from self to self.
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
}
}
continue
}
// Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
continue
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
}
continue
}
if !ready {
continue
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
continue
}
// Prepare packet for batch encryption
ci := hostinfo.ConnectionState
if ci.eKey == nil {
continue
}
// Check if this needs relay - if so, send immediately and skip batching
useRelay := !hostinfo.remote.IsValid()
if useRelay {
// Handle relay sends individually (less common path)
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
continue
}
// Collect for batched encryption
preEncryptionBatch = append(preEncryptionBatch, preEncryptionPacket{
hostinfo: hostinfo,
ci: ci,
packet: packet,
out: out,
})
}
// BATCH ENCRYPTION: Process all collected packets
if len(preEncryptionBatch) > 0 {
f.encryptBatch(preEncryptionBatch, nb, batchPackets, batchAddrs)
}
// Send all accumulated packets in one batch
if len(*batchPackets) > 0 {
batchSize := len(*batchPackets)
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
}
}
}
// encryptBatch processes multiple packets, grouping by ConnectionState to reduce lock acquisitions
func (f *Interface) encryptBatch(batch []preEncryptionPacket, nb []byte, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
lockStart := time.Now()
lockAcquisitions := int64(0)
if noiseutil.EncryptLockNeeded {
// Group packets by ConnectionState to minimize lock acquisitions
// Process packets in order but batch lock acquisitions per CI
var currentCI *ConnectionState
var lockHeld bool
for i := range batch {
ci := batch[i].ci
hostinfo := batch[i].hostinfo
// Validate packet data to prevent nil pointer dereference
if ci == nil || hostinfo == nil {
continue
}
// Switch locks if we're moving to a different ConnectionState
if ci != currentCI {
if lockHeld {
currentCI.writeLock.Unlock()
}
ci.writeLock.Lock()
lockAcquisitions++
currentCI = ci
lockHeld = true
}
c := ci.messageCounter.Add(1)
out := header.Encode(batch[i].out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query lighthouse if needed
if hostinfo.lastRebindCount != f.rebindCount {
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
var err error
out, err = ci.eKey.EncryptDanger(out, out, batch[i].packet, c, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("counter", c).
Error("Failed to encrypt outgoing packet")
continue
}
// Add to output batches
*batchPackets = append(*batchPackets, out)
*batchAddrs = append(*batchAddrs, hostinfo.remote)
}
// Release final lock
if lockHeld {
currentCI.writeLock.Unlock()
}
} else {
// No locks needed - process directly
for i := range batch {
ci := batch[i].ci
hostinfo := batch[i].hostinfo
// Validate packet data to prevent nil pointer dereference
if ci == nil || hostinfo == nil {
continue
}
c := ci.messageCounter.Add(1)
out := header.Encode(batch[i].out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query lighthouse if needed
if hostinfo.lastRebindCount != f.rebindCount {
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
var err error
out, err = ci.eKey.EncryptDanger(out, out, batch[i].packet, c, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("counter", c).
Error("Failed to encrypt outgoing packet")
continue
}
// Add to output batches
*batchPackets = append(*batchPackets, out)
*batchAddrs = append(*batchAddrs, hostinfo.remote)
}
}
// Record metrics
encryptionTime := time.Since(lockStart)
if noiseutil.EncryptLockNeeded {
f.batchMetrics.lockAcquisitions.Inc(lockAcquisitions)
}
f.batchMetrics.encryptionTime.Update(encryptionTime.Nanoseconds())
f.batchMetrics.batchSize.Update(int64(len(batch)))
}
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {