mirror of https://github.com/slackhq/nebula.git
use wg tun library; batching & locking improvements
inside.go | 254 additions
@@ -2,6 +2,8 @@ package nebula

import (
	"net/netip"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/firewall"
@@ -11,6 +13,258 @@ import (
	"github.com/slackhq/nebula/routing"
)
// preEncryptionPacket holds packet data before batch encryption
type preEncryptionPacket struct {
	hostinfo *HostInfo
	ci       *ConnectionState
	packet   []byte
	out      []byte
}

// preEncryptionBatchPool holds reusable preEncryptionPacket slices to reduce allocations
var preEncryptionBatchPool = sync.Pool{
	New: func() any {
		// Pre-allocate with a reasonable capacity
		batch := make([]preEncryptionPacket, 0, 128)
		return &batch
	},
}
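
The pool's call site is not in the visible part of this hunk. A minimal sketch of the usual Get/Put pattern for a pooled slice, assuming the caller truncates before returning it (the reset-before-Put step is an assumption, not shown in this diff):

	// Hypothetical call site for preEncryptionBatchPool (not part of this hunk).
	batchPtr := preEncryptionBatchPool.Get().(*[]preEncryptionPacket)
	batch := (*batchPtr)[:0]
	// ... append preEncryptionPacket values and flush the batch ...
	*batchPtr = batch[:0] // truncate so the capacity is reused on the next Get
	preEncryptionBatchPool.Put(batchPtr)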

// consumeInsidePackets processes multiple packets in a batch for improved performance
// packets: slice of packet buffers to process
// sizes: slice of packet sizes
// count: number of packets to process
// outs: slice of output buffers (one per packet) with virtio headroom
// nb: scratch buffer used for the encryption nonce
// q: queue index
// localCache: firewall conntrack cache
// batchPackets: pre-allocated slice for accumulating encrypted packets
// batchAddrs: pre-allocated slice for accumulating destination addresses
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
	// Reusable per-packet state
	fwPacket := &firewall.Packet{}

	// Reset batch accumulation slices (reuse capacity)
	*batchPackets = (*batchPackets)[:0]
	*batchAddrs = (*batchAddrs)[:0]

	// Collect packets for batched encryption
	preEncryptionBatch := make([]preEncryptionPacket, 0, count)

	// Process each packet in the batch
	for i := 0; i < count; i++ {
		packet := packets[i][:sizes[i]]
		out := outs[i]

		// Inline the consumeInsidePacket logic for better performance
		err := newPacket(packet, false, fwPacket)
		if err != nil {
			if f.l.Level >= logrus.DebugLevel {
				f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
			}
			continue
		}

		// Ignore local broadcast packets
		if f.dropLocalBroadcast {
			if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
				continue
			}
		}

		if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
			// Immediately forward packets from self to self.
			if immediatelyForwardToSelf {
				_, err := f.readers[q].Write(packet)
				if err != nil {
					f.l.WithError(err).Error("Failed to forward to tun")
				}
			}
			continue
		}

		// Ignore multicast packets
		if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
			continue
		}

		hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
			hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
		})

		if hostinfo == nil {
			f.rejectInside(packet, out, q)
			if f.l.Level >= logrus.DebugLevel {
				f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
					WithField("fwPacket", fwPacket).
					Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
			}
			continue
		}

		if !ready {
			continue
		}

		dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
		if dropReason != nil {
			f.rejectInside(packet, out, q)
			if f.l.Level >= logrus.DebugLevel {
				hostinfo.logger(f.l).
					WithField("fwPacket", fwPacket).
					WithField("reason", dropReason).
					Debugln("dropping outbound packet")
			}
			continue
		}

		// Prepare packet for batch encryption
		ci := hostinfo.ConnectionState
		if ci.eKey == nil {
			continue
		}

		// Check if this packet needs a relay - if so, send immediately and skip batching
		useRelay := !hostinfo.remote.IsValid()
		if useRelay {
			// Handle relay sends individually (less common path)
			f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
			continue
		}

		// Collect for batched encryption
		preEncryptionBatch = append(preEncryptionBatch, preEncryptionPacket{
			hostinfo: hostinfo,
			ci:       ci,
			packet:   packet,
			out:      out,
		})
	}

	// Batch encryption: process all collected packets
	if len(preEncryptionBatch) > 0 {
		f.encryptBatch(preEncryptionBatch, nb, batchPackets, batchAddrs)
	}

	// Send all accumulated packets in one batch
	if len(*batchPackets) > 0 {
		batchSize := len(*batchPackets)
		n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
		if err != nil {
			f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
		}
	}
}
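
For context, a minimal sketch of how a per-queue tun read loop might drive this function, assuming the vectored BatchSize()/Read(bufs, sizes, offset) shape of the wireguard-go tun package (golang.zx2c4.com/wireguard/tun) named in the commit title. The loop, the method name, `dev`, `bufSize`, and the headroom sizes are illustrative assumptions, not part of this diff:

	// Illustrative caller (not part of this commit): a per-queue read loop
	// built on the wireguard-go style batch tun API.
	func (f *Interface) readLoopSketch(dev tun.Device, q int, cache firewall.ConntrackCache) {
		const bufSize = 9001 // assumed per-packet buffer size
		n := dev.BatchSize()
		packets := make([][]byte, n)
		outs := make([][]byte, n)
		sizes := make([]int, n)
		for i := range packets {
			packets[i] = make([]byte, bufSize)
			outs[i] = make([]byte, bufSize+256) // extra room for nebula/virtio headers (a guess)
		}
		nb := make([]byte, 12) // nonce scratch buffer passed through to encryption
		var batchPackets [][]byte
		var batchAddrs []netip.AddrPort

		for {
			count, err := dev.Read(packets, sizes, 0) // a real caller would pass the virtio headroom offset
			if err != nil {
				return
			}
			f.consumeInsidePackets(packets, sizes, count, outs, nb, q, cache, &batchPackets, &batchAddrs)
		}
	}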

// encryptBatch processes multiple packets, grouping by ConnectionState to reduce lock acquisitions
func (f *Interface) encryptBatch(batch []preEncryptionPacket, nb []byte, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
	start := time.Now()
	lockAcquisitions := int64(0)

	if noiseutil.EncryptLockNeeded {
		// Group packets by ConnectionState to minimize lock acquisitions:
		// process packets in order, but hold one lock across each run of
		// consecutive packets that share a ConnectionState
		var currentCI *ConnectionState
		var lockHeld bool

		for i := range batch {
			ci := batch[i].ci
			hostinfo := batch[i].hostinfo

			// Validate packet data to prevent a nil pointer dereference
			if ci == nil || hostinfo == nil {
				continue
			}

			// Switch locks if we're moving to a different ConnectionState
			if ci != currentCI {
				if lockHeld {
					currentCI.writeLock.Unlock()
				}
				ci.writeLock.Lock()
				lockAcquisitions++
				currentCI = ci
				lockHeld = true
			}

			c := ci.messageCounter.Add(1)
			out := header.Encode(batch[i].out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
			f.connectionManager.Out(hostinfo)

			// Query the lighthouse if needed
			if hostinfo.lastRebindCount != f.rebindCount {
				f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
				hostinfo.lastRebindCount = f.rebindCount
				if f.l.Level >= logrus.DebugLevel {
					f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
				}
			}

			var err error
			out, err = ci.eKey.EncryptDanger(out, out, batch[i].packet, c, nb)
			if err != nil {
				hostinfo.logger(f.l).WithError(err).
					WithField("counter", c).
					Error("Failed to encrypt outgoing packet")
				continue
			}

			// Add to output batches
			*batchPackets = append(*batchPackets, out)
			*batchAddrs = append(*batchAddrs, hostinfo.remote)
		}

		// Release the final lock
		if lockHeld {
			currentCI.writeLock.Unlock()
		}
	} else {
		// No locks needed - process directly
		for i := range batch {
			ci := batch[i].ci
			hostinfo := batch[i].hostinfo

			// Validate packet data to prevent a nil pointer dereference
			if ci == nil || hostinfo == nil {
				continue
			}

			c := ci.messageCounter.Add(1)
			out := header.Encode(batch[i].out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
			f.connectionManager.Out(hostinfo)

			// Query the lighthouse if needed
			if hostinfo.lastRebindCount != f.rebindCount {
				f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
				hostinfo.lastRebindCount = f.rebindCount
				if f.l.Level >= logrus.DebugLevel {
					f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
				}
			}

			var err error
			out, err = ci.eKey.EncryptDanger(out, out, batch[i].packet, c, nb)
			if err != nil {
				hostinfo.logger(f.l).WithError(err).
					WithField("counter", c).
					Error("Failed to encrypt outgoing packet")
				continue
			}

			// Add to output batches
			*batchPackets = append(*batchPackets, out)
			*batchAddrs = append(*batchAddrs, hostinfo.remote)
		}
	}

	// Record metrics
	encryptionTime := time.Since(start)
	if noiseutil.EncryptLockNeeded {
		f.batchMetrics.lockAcquisitions.Inc(lockAcquisitions)
	}
	f.batchMetrics.encryptionTime.Update(encryptionTime.Nanoseconds())
	f.batchMetrics.batchSize.Update(int64(len(batch)))
}
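
Note the design choice in the locking branch above: the batch is scanned in arrival order, so a lock is coalesced only across consecutive packets that share a ConnectionState; traffic interleaving two peers still pays one acquisition per run. Sorting the batch by ConnectionState would maximize coalescing but would reorder packets across flows. A small illustrative helper (not in this diff) that computes the same count the lockAcquisitions metric records:

	// countLockAcquisitions mirrors the run-length grouping in encryptBatch:
	// one acquisition per run of consecutive packets sharing a ConnectionState.
	// Illustrative only; not part of this commit.
	func countLockAcquisitions(batch []preEncryptionPacket) int64 {
		var n int64
		var current *ConnectionState
		for i := range batch {
			if batch[i].ci == nil {
				continue // encryptBatch skips these as well
			}
			if batch[i].ci != current {
				n++
				current = batch[i].ci
			}
		}
		return n
	}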

func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
	err := newPacket(packet, false, fwPacket)
	if err != nil {