use wg tun library; batching & locking improvements

This commit is contained in:
Jay Wren
2025-11-04 15:04:24 -05:00
parent 42bee7cf17
commit 5cc3ff594a
33 changed files with 1353 additions and 121 deletions

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"runtime"
@@ -22,6 +21,7 @@ import (
)
const mtu = 9001
const virtioNetHdrLen = overlay.VirtioNetHdrLen
type InterfaceConfig struct {
HostMap *HostMap
@@ -50,6 +50,13 @@ type InterfaceConfig struct {
l *logrus.Logger
}
type batchMetrics struct {
udpReadSize metrics.Histogram
encryptionTime metrics.Histogram // Time spent in encryption (including lock waits)
batchSize metrics.Histogram // Dynamic batch sizes being used
lockAcquisitions metrics.Counter // Number of lock acquisitions (should be minimal)
}
type Interface struct {
hostMap *HostMap
outside udp.Conn
@@ -87,11 +94,12 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []udp.Conn
readers []io.ReadWriteCloser
readers []overlay.BatchReadWriter
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
batchMetrics *batchMetrics
l *logrus.Logger
}
@@ -178,7 +186,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
readers: make([]overlay.BatchReadWriter, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
@@ -194,6 +202,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
batchMetrics: &batchMetrics{
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
encryptionTime: metrics.GetOrRegisterHistogram("batch.encryption_time_ns", nil, metrics.NewUniformSample(1024)),
batchSize: metrics.GetOrRegisterHistogram("batch.size", nil, metrics.NewUniformSample(1024)),
lockAcquisitions: metrics.GetOrRegisterCounter("batch.lock_acquisitions", nil),
},
l: c.l,
}
@@ -233,7 +247,7 @@ func (f *Interface) activate() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
var reader io.ReadWriteCloser = f.inside
var reader overlay.BatchReadWriter = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
@@ -274,39 +288,68 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
// Pre-allocate output buffers for batch processing
batchSize := li.BatchSize()
outs := make([][]byte, batchSize)
for idx := range outs {
// Allocate full buffer with virtio header space
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
}
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
nb := make([]byte, 12)
li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
batchSize := reader.BatchSize()
// Allocate buffers for batch reading
bufs := make([][]byte, batchSize)
for idx := range bufs {
bufs[idx] = make([]byte, mtu)
}
sizes := make([]int, batchSize)
// Allocate output buffers for batch processing (one per packet)
// Each has virtio header headroom to avoid copies on write
outs := make([][]byte, batchSize)
for idx := range outs {
outBuf := make([]byte, virtioNetHdrLen+mtu)
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
}
// Pre-allocate batch accumulation buffers for sending
batchPackets := make([][]byte, 0, batchSize)
batchAddrs := make([]netip.AddrPort, 0, batchSize)
// Pre-allocate nonce buffer (reused for all encryptions)
nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
n, err := reader.BatchRead(bufs, sizes)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
f.l.WithError(err).Error("Error while reading outbound packet")
f.l.WithError(err).Error("Error while batch reading outbound packets")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
// Process all packets in the batch at once
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
}
}