batch tun reads

This commit is contained in:
Jay Wren
2026-02-03 17:12:44 -05:00
parent 15333f9fed
commit 30db76ed79
5 changed files with 447 additions and 7 deletions

139
inside.go
View File

@@ -9,8 +9,75 @@ import (
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/udp"
)
// consumeInsidePacketBatched is the batching variant of consumeInsidePacket.
// A packet that passes validation, routing and the outbound firewall is handed
// to sendNoMetricsBatched, which appends it to pendingPackets rather than
// writing it to the network. Flushing pendingPackets with WriteBatch is the
// caller's job.
func (f *Interface) consumeInsidePacketBatched(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache, pendingPackets *[]udp.BatchPacket) {
	if err := newPacket(packet, false, fwPacket); err != nil {
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
		}
		return
	}

	// Destinations that never leave this host (or are configured to be dropped).
	// The cases are checked in order and the first match wins.
	switch {
	case f.dropLocalBroadcast && f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr):
		// Ignore local broadcast packets
		return

	case f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr):
		// Addressed to ourselves: optionally loop it straight back into the tun.
		if immediatelyForwardToSelf {
			if _, err := f.readers[q].Write(packet); err != nil {
				f.l.WithError(err).Error("Failed to forward to tun")
			}
		}
		return

	case f.dropMulticast && fwPacket.RemoteAddr.IsMulticast():
		// Ignore multicast packets
		return
	}

	// While a handshake is still pending the packet is cached on the handshake
	// host info so it can be sent later.
	cacheForLater := func(hh *HandshakeHostInfo) {
		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
	}
	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, cacheForLater)

	if hostinfo == nil {
		f.rejectInside(packet, out, q)
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
				WithField("fwPacket", fwPacket).
				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
		}
		return
	}

	if !ready {
		return
	}

	if dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache); dropReason != nil {
		f.rejectInside(packet, out, q)
		if f.l.Level >= logrus.DebugLevel {
			hostinfo.logger(f.l).
				WithField("fwPacket", fwPacket).
				WithField("reason", dropReason).
				Debugln("dropping outbound packet")
		}
		return
	}

	f.sendNoMetricsBatched(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q, pendingPackets)
}
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
@@ -409,3 +476,75 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
}
// sendNoMetricsBatched is like sendNoMetrics but queues the encrypted packet
// for batched sending instead of writing it to the socket. The caller must
// flush pendingPackets with WriteBatch.
//
// The destination is remote when it is valid, otherwise hostinfo.remote. When
// neither is valid the packet is sent through a relay immediately (via
// SendVia) and is NOT added to pendingPackets.
//
// Parameters:
//   - t, st: message type and subtype encoded into the header.
//   - ci: connection state providing the encryption key and message counter;
//     nothing is sent when ci.eKey is nil.
//   - p: plaintext payload; nb: nonce scratch buffer passed to EncryptDanger.
//   - out: scratch buffer the header and ciphertext are built in. It is only
//     borrowed: a copy of the result is queued.
//   - q: queue index (unused by this function).
//   - pendingPackets: slice the encrypted packet is appended to.
func (f *Interface) sendNoMetricsBatched(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, pendingPackets *[]udp.BatchPacket) {
	if ci.eKey == nil {
		return
	}

	// Resolve the destination exactly once. Both the buffer layout below and
	// the send path at the end depend on whether a direct address exists, so
	// they must be derived from the same snapshot of hostinfo.remote.
	// NOTE(review): hostinfo.remote is a plain field; if another goroutine can
	// update it, reading it twice could reserve no relay header and still take
	// the relay branch — confirm against the writers of hostinfo.remote.
	addr := remote
	if !addr.IsValid() {
		addr = hostinfo.remote
	}
	useRelay := !addr.IsValid()

	fullOut := out
	if useRelay {
		if len(out) < header.Len {
			// Grow the slice so the re-slice below is in range.
			out = out[:header.Len]
		}
		// Leave header.Len bytes free at the front of the buffer; the relay
		// path hands fullOut (including that prefix) to SendVia.
		out = out[header.Len:]
	}

	if noiseutil.EncryptLockNeeded {
		ci.writeLock.Lock()
	}
	c := ci.messageCounter.Add(1)

	out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
	f.connectionManager.Out(hostinfo)

	// Ask the lighthouse about this host once after a rebind has been observed.
	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
		hostinfo.lastRebindCount = f.rebindCount
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
		}
	}

	var err error
	out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
	if noiseutil.EncryptLockNeeded {
		ci.writeLock.Unlock()
	}
	if err != nil {
		hostinfo.logger(f.l).WithError(err).
			WithField("udpAddr", remote).WithField("counter", c).
			WithField("attemptedCounter", c).
			Error("Failed to encrypt outgoing packet")
		return
	}

	if useRelay {
		// Relay path - send immediately, not batched. The first relay that
		// can be resolved is used; relays that fail to resolve are removed.
		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
			if err != nil {
				hostinfo.relayState.DeleteRelay(relayIP)
				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetricsBatched failed to find HostInfo")
				continue
			}
			f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
			break
		}
		return
	}

	// Copy the payload since the buffer will be reused for the next packet
	// before the batch is flushed.
	payload := make([]byte, len(out))
	copy(payload, out)
	*pendingPackets = append(*pendingPackets, udp.BatchPacket{Payload: payload, Addr: addr})
}

View File

@@ -48,6 +48,8 @@ type InterfaceConfig struct {
ConntrackCacheTimeout time.Duration
l *logrus.Logger
tunBatchSize int // batch size for TUN read/write batching, 0 to disable
}
type Interface struct {
@@ -86,8 +88,9 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []udp.Conn
readers []io.ReadWriteCloser
writers []udp.Conn
readers []io.ReadWriteCloser
tunBatchSize int // batch size for TUN read/write batching
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
@@ -187,6 +190,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
tunBatchSize: c.tunBatchSize,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
@@ -244,6 +248,15 @@ func (f *Interface) activate() {
f.readers[i] = reader
}
// Enable batch reading on all readers if batch size > 1
if f.tunBatchSize > 1 {
for i := 0; i < f.routines; i++ {
if err := overlay.EnableBatchReading(f.readers[i]); err != nil {
f.l.WithError(err).WithField("routine", i).Warn("Failed to enable batch reading, falling back to single reads")
}
}
}
if err := f.inside.Activate(); err != nil {
f.inside.Close()
f.l.Fatal(err)
@@ -287,13 +300,21 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
// Check if batch reading is available and enabled
batchReader := overlay.AsBatchReader(reader)
if batchReader != nil && f.tunBatchSize > 1 {
f.listenInBatched(reader, batchReader, i, conntrackCache)
return
}
// Fallback to single-packet reading
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
if err != nil {
@@ -310,6 +331,54 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
}
// listenInBatched is the batched counterpart of the single-packet read loop in
// listenIn. Each iteration reads up to tunBatchSize packets with one ReadBatch
// call, runs every packet through consumeInsidePacketBatched, and then flushes
// the queued UDP packets with a single WriteBatch on writers[i].
//
// The loop returns only when ReadBatch fails with os.ErrClosed after the
// interface has been marked closed; any other read error is logged and
// terminates the process with exit code 2. The reader argument is not used by
// this loop; all reads go through batchReader.
func (f *Interface) listenInBatched(reader io.ReadWriteCloser, batchReader overlay.BatchReader, i int, conntrackCache *firewall.ConntrackCacheTicker) {
	capacity := f.tunBatchSize

	// One mtu-sized buffer per batch slot, plus the length read into each.
	bufs := make([][]byte, capacity)
	for slot := range bufs {
		bufs[slot] = make([]byte, mtu)
	}
	lengths := make([]int, capacity)

	// Scratch state shared by every packet handled on this routine.
	scratch := make([]byte, mtu)
	fwPacket := new(firewall.Packet)
	nonce := make([]byte, 12)

	// Outgoing UDP packets collected during one batch.
	pending := make([]udp.BatchPacket, 0, capacity)

	for {
		got, err := batchReader.ReadBatch(bufs, lengths)
		if err != nil {
			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
				return
			}

			f.l.WithError(err).Error("Error while reading outbound packets")
			os.Exit(2)
		}
		if got == 0 {
			continue
		}

		// The conntrack cache is fetched once per batch, not once per packet.
		cache := conntrackCache.Get(f.l)
		for idx := 0; idx < got; idx++ {
			f.consumeInsidePacketBatched(bufs[idx][:lengths[idx]], fwPacket, nonce, scratch, i, cache, &pending)
		}

		if len(pending) == 0 {
			continue
		}

		// Flush everything queued for this batch, then reuse the slice.
		f.writers[i].WriteBatch(pending)
		pending = pending[:0]
	}
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError)

View File

@@ -250,6 +250,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
tunBatchSize: c.GetInt("tun.batch", 64),
}
var ifce *Interface

View File

@@ -16,3 +16,38 @@ type Device interface {
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error)
}
// BatchReader is an optional interface that devices can implement
// to support reading multiple packets in a single batch operation.
// This can significantly reduce syscall overhead under high load.
//
// Use AsBatchReader to discover whether a reader implements it.
type BatchReader interface {
// ReadBatch reads up to len(packets) packets into the provided buffers.
// Packet i is read into packets[i] and its length is stored in sizes[i];
// only the first n entries of both slices are meaningful, where n is the
// returned count.
//
// Returns the number of packets read, or an error.
// A return of (0, nil) indicates no packets were available; callers must
// tolerate it and simply call again. Implementations are also free to
// block until at least one packet is available.
ReadBatch(packets [][]byte, sizes []int) (int, error)
}
// AsBatchReader reports the batch-reading view of r. The result is r itself,
// typed as BatchReader, when r implements that interface, and nil when it does
// not (including when r is nil).
func AsBatchReader(r io.ReadWriteCloser) BatchReader {
	// A failed assertion leaves br as the zero BatchReader, i.e. a nil interface.
	br, _ := r.(BatchReader)
	return br
}
// BatchEnabler is an optional interface for devices that need explicit
// enabling of batch read support (e.g., setting non-blocking mode).
//
// Use the package-level EnableBatchReading helper to invoke it on a value
// that may or may not implement the interface.
type BatchEnabler interface {
// EnableBatchReading prepares the device for ReadBatch calls and returns
// an error if the preparation fails.
EnableBatchReading() error
}
// EnableBatchReading turns on batch reading for d when d implements
// BatchEnabler, returning whatever d.EnableBatchReading returns. Values that
// do not implement BatchEnabler need no explicit enabling as far as this
// helper is concerned, so nil is returned for them.
func EnableBatchReading(d interface{}) error {
	enabler, ok := d.(BatchEnabler)
	if !ok {
		return nil
	}
	return enabler.EnableBatchReading()
}

View File

@@ -34,6 +34,7 @@ type tun struct {
TXQueueLen int
deviceIndex int
ioctlFd uintptr
nonBlocking bool // true if fd is in non-blocking mode
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -239,7 +240,7 @@ func (t *tun) SupportsMultiqueue() bool {
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
fd, err := unix.Open("/dev/net/tun", os.O_RDWR|unix.O_NONBLOCK, 0)
if err != nil {
return nil, err
}
@@ -251,9 +252,99 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return &tunBatchReader{fd: fd, device: t.Device}, nil
}
return file, nil
// tunBatchReader implements BatchReader for efficient batch packet reading.
// It is the reader returned by NewMultiQueueReader and works directly on a raw
// file descriptor (Read, Write, Close and ReadBatch all issue unix syscalls on
// fd) rather than going through an os.File.
type tunBatchReader struct {
// fd is the file descriptor of the opened /dev/net/tun handle;
// NewMultiQueueReader opens it with O_NONBLOCK.
fd int
// device is the TUN device name, copied from the owning tun's Device field.
device string
}
// Read reads a single packet into b. The read is attempted first; only when
// the descriptor reports that nothing is available (EAGAIN/EWOULDBLOCK) does
// the call wait in poll() for readability and then try again. Any other read
// error, and any poll error other than EINTR, is returned to the caller.
func (r *tunBatchReader) Read(b []byte) (int, error) {
	for {
		n, err := unix.Read(r.fd, b)
		if err == nil {
			return n, nil
		}
		if err != unix.EAGAIN && err != unix.EWOULDBLOCK {
			return n, err
		}

		// Nothing queued yet: sleep until the fd becomes readable. An
		// interrupted poll simply leads to another read attempt.
		waitSet := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
		if _, pollErr := unix.Poll(waitSet, -1); pollErr != nil && pollErr != unix.EINTR {
			return 0, pollErr
		}
	}
}
// Write writes one packet to the TUN descriptor.
//
// The descriptor is opened in non-blocking mode, so a write can fail with
// EAGAIN/EWOULDBLOCK when the kernel cannot accept the packet right now. To
// keep the blocking behaviour callers expect from an io.Writer (and to mirror
// what Read does on this type), such a write waits in poll() for the
// descriptor to become writable and is then retried. Writes and polls
// interrupted by a signal (EINTR) are retried as well. Any other error is
// returned unchanged together with the byte count reported by the syscall.
func (r *tunBatchReader) Write(b []byte) (int, error) {
	for {
		n, err := unix.Write(r.fd, b)
		if err == nil {
			return n, nil
		}
		if err == unix.EINTR {
			continue
		}
		if err != unix.EAGAIN && err != unix.EWOULDBLOCK {
			return n, err
		}

		// Queue is full: wait until the fd is writable, then try again.
		pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLOUT}}
		if _, pollErr := unix.Poll(pfds, -1); pollErr != nil && pollErr != unix.EINTR {
			return 0, pollErr
		}
	}
}
// Close closes the underlying file descriptor and returns the result of the
// close syscall.
func (r *tunBatchReader) Close() error {
	return unix.Close(r.fd)
}
// ReadBatch reads up to min(len(packets), len(sizes)) packets from the TUN
// device. It drains all available packets without blocking, using poll() only
// when no packets have been read yet.
//
// Packet i is stored in packets[i] and its length in sizes[i]. The returned
// count is the number of packets read.
//
// Once at least one packet has been read, any condition that stops the drain
// (no more data, a read error, a zero-length read) ends the call with the
// packets collected so far and a nil error; a persistent error resurfaces on
// the next call. With nothing read yet:
//   - EAGAIN/EWOULDBLOCK waits in poll() and retries (EINTR is retried too);
//   - EBADF is reported as os.ErrClosed, because that is what reading a
//     closed descriptor produces and the batched read loop recognises a clean
//     shutdown via errors.Is(err, os.ErrClosed);
//   - any other error is returned unchanged;
//   - a zero-length read is reported as io.EOF.
func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) {
	maxPackets := len(packets)
	if len(sizes) < maxPackets {
		maxPackets = len(sizes)
	}

	count := 0
	for count < maxPackets {
		n, err := unix.Read(r.fd, packets[count])
		if err == nil && n > 0 {
			sizes[count] = n
			count++
			continue
		}

		// The drain stopped. If we already have packets, return them.
		if count > 0 {
			return count, nil
		}

		switch {
		case err == unix.EAGAIN || err == unix.EWOULDBLOCK:
			// No packets yet, wait for at least one.
			pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
			if _, pollErr := unix.Poll(pfds, -1); pollErr != nil && pollErr != unix.EINTR {
				return 0, pollErr
			}
		case err == unix.EBADF:
			// The descriptor is no longer open: report it the way the rest of
			// the code base detects closed readers.
			return 0, os.ErrClosed
		case err != nil:
			return 0, err
		default:
			// err == nil and nothing was read.
			return 0, io.EOF
		}
	}
	return count, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -284,6 +375,111 @@ func (t *tun) Write(b []byte) (int, error) {
}
}
// EnableBatchReading switches the TUN fd to non-blocking mode, which ReadBatch
// relies on, and records that fact in t.nonBlocking. Calling it again after it
// has succeeded is a no-op. It should be called before using ReadBatch.
func (t *tun) EnableBatchReading() error {
	if !t.nonBlocking {
		if err := unix.SetNonblock(t.fd, true); err != nil {
			return err
		}
		// Only flip the flag once the fd really is non-blocking.
		t.nonBlocking = true
	}
	return nil
}
// Read overrides the embedded reader's Read so that it keeps working once the
// fd has been put into non-blocking mode. While non-blocking mode is off the
// call is forwarded to the embedded ReadWriteCloser. Otherwise the fd is read
// directly and, whenever it has no data (EAGAIN/EWOULDBLOCK), the call waits
// in poll() and retries.
func (t *tun) Read(b []byte) (int, error) {
	if !t.nonBlocking {
		// Blocking fd: the embedded reader already does the right thing.
		return t.ReadWriteCloser.Read(b)
	}

	for {
		n, err := unix.Read(t.fd, b)
		if err == nil {
			return n, nil
		}

		wouldBlock := err == unix.EAGAIN || err == unix.EWOULDBLOCK
		if !wouldBlock {
			return n, err
		}

		// No data yet: wait until the fd is readable. A poll interrupted by
		// a signal just leads to another attempt.
		readable := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
		_, pollErr := unix.Poll(readable, -1)
		if pollErr != nil && pollErr != unix.EINTR {
			return 0, pollErr
		}
	}
}
// ReadBatch reads up to min(len(packets), len(sizes)) packets from the TUN
// device. EnableBatchReading must be called first; until then the call falls
// back to one blocking read through the embedded ReadWriteCloser and returns
// at most one packet.
//
// Packet i is stored in packets[i] and its length in sizes[i]. The returned
// count is the number of packets read. When there is no room for even one
// packet (either slice is empty) the call returns (0, nil) without touching
// the device.
//
// In non-blocking mode all immediately available packets are drained. Once at
// least one packet has been read, anything that stops the drain ends the call
// with the packets collected so far and a nil error. With nothing read yet:
//   - EAGAIN/EWOULDBLOCK waits in poll() and retries (EINTR is retried too);
//   - EBADF is reported as os.ErrClosed so that callers checking
//     errors.Is(err, os.ErrClosed) recognise a closed device
//     (NOTE(review): assumes closing the tun closes t.fd — confirm);
//   - any other error is returned unchanged;
//   - a zero-length read is reported as io.EOF.
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
	maxPackets := len(packets)
	if len(sizes) < maxPackets {
		maxPackets = len(sizes)
	}
	if maxPackets == 0 {
		// Nothing can be stored; avoid indexing into empty slices below.
		return 0, nil
	}

	if !t.nonBlocking {
		// Fallback to single read if non-blocking not enabled
		n, err := t.ReadWriteCloser.Read(packets[0])
		if err != nil {
			return 0, err
		}
		sizes[0] = n
		return 1, nil
	}

	count := 0
	for count < maxPackets {
		n, err := unix.Read(t.fd, packets[count])
		if err == nil && n > 0 {
			sizes[count] = n
			count++
			continue
		}

		// The drain stopped. If we already have packets, return them.
		if count > 0 {
			return count, nil
		}

		switch {
		case err == unix.EAGAIN || err == unix.EWOULDBLOCK:
			// No packets yet, wait for at least one.
			pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
			if _, pollErr := unix.Poll(pfds, -1); pollErr != nil && pollErr != unix.EINTR {
				return 0, pollErr
			}
		case err == unix.EBADF:
			return 0, os.ErrClosed
		case err != nil:
			return 0, err
		default:
			// err == nil and nothing was read.
			return 0, io.EOF
		}
	}
	return count, nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)