diff --git a/inside.go b/inside.go index 369b86b6..c723b812 100644 --- a/inside.go +++ b/inside.go @@ -9,8 +9,75 @@ import ( "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/routing" + "github.com/slackhq/nebula/udp" ) +// consumeInsidePacketBatched is a variant of consumeInsidePacket that queues +// outgoing packets into pendingPackets instead of sending them immediately. +// The caller is responsible for flushing pendingPackets with WriteBatch. +func (f *Interface) consumeInsidePacketBatched(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache, pendingPackets *[]udp.BatchPacket) { + err := newPacket(packet, false, fwPacket) + if err != nil { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + } + return + } + + // Ignore local broadcast packets + if f.dropLocalBroadcast { + if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { + return + } + } + + if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { + if immediatelyForwardToSelf { + _, err := f.readers[q].Write(packet) + if err != nil { + f.l.WithError(err).Error("Failed to forward to tun") + } + } + return + } + + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { + return + } + + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + + if hostinfo == nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). + WithField("fwPacket", fwPacket). + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + } + return + } + + if !ready { + return + } + + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + if dropReason == nil { + f.sendNoMetricsBatched(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q, pendingPackets) + } else { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). + WithField("fwPacket", fwPacket). + WithField("reason", dropReason). + Debugln("dropping outbound packet") + } + } +} + func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { @@ -409,3 +476,75 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } + +// sendNoMetricsBatched is like sendNoMetrics but queues the packet for batched sending +// instead of sending immediately. The caller must flush pendingPackets with WriteBatch. 
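+//
+// A minimal usage sketch (illustrative only; batchSize and the surrounding
+// caller state are assumptions, not part of this change):
+//
+//	pending := make([]udp.BatchPacket, 0, batchSize)
+//	f.sendNoMetricsBatched(header.Message, 0, hostinfo.ConnectionState, hostinfo,
+//		netip.AddrPort{}, packet, nb, out, q, &pending)
+//	if len(pending) > 0 {
+//		f.writers[q].WriteBatch(pending)
+//	}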
+func (f *Interface) sendNoMetricsBatched(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, pendingPackets *[]udp.BatchPacket) { + if ci.eKey == nil { + return + } + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() + fullOut := out + + if useRelay { + if len(out) < header.Len { + out = out[:header.Len] + } + out = out[header.Len:] + } + + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + var err error + out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("udpAddr", remote).WithField("counter", c). + WithField("attemptedCounter", c). + Error("Failed to encrypt outgoing packet") + return + } + + // Queue the packet for batched sending + var addr netip.AddrPort + if remote.IsValid() { + addr = remote + } else if hostinfo.remote.IsValid() { + addr = hostinfo.remote + } else { + // Relay path - send immediately, not batched + for _, relayIP := range hostinfo.relayState.CopyRelayIps() { + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) + if err != nil { + hostinfo.relayState.DeleteRelay(relayIP) + hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetricsBatched failed to find HostInfo") + continue + } + f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) + break + } + return + } + + // Copy the payload since the buffer will be reused + payload := make([]byte, len(out)) + copy(payload, out) + *pendingPackets = append(*pendingPackets, udp.BatchPacket{Payload: payload, Addr: addr}) +} diff --git a/interface.go b/interface.go index 61b1f228..069a9123 100644 --- a/interface.go +++ b/interface.go @@ -48,6 +48,8 @@ type InterfaceConfig struct { ConntrackCacheTimeout time.Duration l *logrus.Logger + + tunBatchSize int // batch size for TUN read/write batching, 0 to disable } type Interface struct { @@ -86,8 +88,9 @@ type Interface struct { conntrackCacheTimeout time.Duration - writers []udp.Conn - readers []io.ReadWriteCloser + writers []udp.Conn + readers []io.ReadWriteCloser + tunBatchSize int // batch size for TUN read/write batching metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -187,6 +190,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { relayManager: c.relayManager, connectionManager: c.connectionManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, + tunBatchSize: c.tunBatchSize, metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, @@ -244,6 +248,15 @@ func (f *Interface) activate() { f.readers[i] = reader } + // Enable batch reading on all readers if batch size > 1 + if f.tunBatchSize > 1 { + for i := 0; i < f.routines; i++ { + if err := overlay.EnableBatchReading(f.readers[i]); err != nil { + f.l.WithError(err).WithField("routine", 
i).Warn("Failed to enable batch reading, falling back to single reads") + } + } + } + if err := f.inside.Activate(); err != nil { f.inside.Close() f.l.Fatal(err) @@ -287,13 +300,21 @@ func (f *Interface) listenOut(i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + + // Check if batch reading is available and enabled + batchReader := overlay.AsBatchReader(reader) + if batchReader != nil && f.tunBatchSize > 1 { + f.listenInBatched(reader, batchReader, i, conntrackCache) + return + } + + // Fallback to single-packet reading packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - for { n, err := reader.Read(packet) if err != nil { @@ -310,6 +331,54 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } } +func (f *Interface) listenInBatched(reader io.ReadWriteCloser, batchReader overlay.BatchReader, i int, conntrackCache *firewall.ConntrackCacheTicker) { + batchSize := f.tunBatchSize + + // Pre-allocate buffers for batch reading + packets := make([][]byte, batchSize) + for j := range packets { + packets[j] = make([]byte, mtu) + } + sizes := make([]int, batchSize) + + // Pre-allocate buffers for packet processing + out := make([]byte, mtu) + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + // Pre-allocate buffer for batched UDP writes + pendingPackets := make([]udp.BatchPacket, 0, batchSize) + + for { + // Read a batch of packets from TUN + n, err := batchReader.ReadBatch(packets, sizes) + if err != nil { + if errors.Is(err, os.ErrClosed) && f.closed.Load() { + return + } + + f.l.WithError(err).Error("Error while reading outbound packets") + os.Exit(2) + } + + if n == 0 { + continue + } + + // Process all packets in the batch + cache := conntrackCache.Get(f.l) + for j := 0; j < n; j++ { + f.consumeInsidePacketBatched(packets[j][:sizes[j]], fwPacket, nb, out, i, cache, &pendingPackets) + } + + // Flush all pending UDP writes + if len(pendingPackets) > 0 { + f.writers[i].WriteBatch(pendingPackets) + pendingPackets = pendingPackets[:0] + } + } +} + func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) diff --git a/main.go b/main.go index 17aaa548..ba90d0dc 100644 --- a/main.go +++ b/main.go @@ -250,6 +250,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, + tunBatchSize: c.GetInt("tun.batch", 64), } var ifce *Interface diff --git a/overlay/device.go b/overlay/device.go index b6077aba..d71567cb 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -16,3 +16,38 @@ type Device interface { SupportsMultiqueue() bool NewMultiQueueReader() (io.ReadWriteCloser, error) } + +// BatchReader is an optional interface that devices can implement +// to support reading multiple packets in a single batch operation. +// This can significantly reduce syscall overhead under high load. +type BatchReader interface { + // ReadBatch reads up to len(packets) packets into the provided buffers. + // Each packet is read into packets[i] and its length is stored in sizes[i]. + // Returns the number of packets read, or an error. + // A return of (0, nil) indicates no packets were available (non-blocking). 
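+	//
+	// An illustrative caller; br, mtu and handle are placeholder names and the
+	// buffer count is an assumption of this sketch, not a requirement of the
+	// interface:
+	//
+	//	packets := make([][]byte, 64)
+	//	for i := range packets {
+	//		packets[i] = make([]byte, mtu)
+	//	}
+	//	sizes := make([]int, 64)
+	//	n, err := br.ReadBatch(packets, sizes)
+	//	if err != nil {
+	//		return err
+	//	}
+	//	for i := 0; i < n; i++ {
+	//		handle(packets[i][:sizes[i]])
+	//	}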
+ ReadBatch(packets [][]byte, sizes []int) (int, error) +} + +// AsBatchReader returns a BatchReader if the reader supports batch operations, +// otherwise returns nil. +func AsBatchReader(r io.ReadWriteCloser) BatchReader { + if br, ok := r.(BatchReader); ok { + return br + } + return nil +} + +// BatchEnabler is an optional interface for devices that need explicit +// enabling of batch read support (e.g., setting non-blocking mode). +type BatchEnabler interface { + EnableBatchReading() error +} + +// EnableBatchReading enables batch reading on the device if supported. +// Returns nil if the device doesn't support or need explicit enabling. +func EnableBatchReading(d interface{}) error { + if be, ok := d.(BatchEnabler); ok { + return be.EnableBatchReading() + } + return nil +} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..719f5b67 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -34,6 +34,7 @@ type tun struct { TXQueueLen int deviceIndex int ioctlFd uintptr + nonBlocking bool // true if fd is in non-blocking mode Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -239,7 +240,7 @@ func (t *tun) SupportsMultiqueue() bool { } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) + fd, err := unix.Open("/dev/net/tun", os.O_RDWR|unix.O_NONBLOCK, 0) if err != nil { return nil, err } @@ -251,9 +252,99 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + return &tunBatchReader{fd: fd, device: t.Device}, nil +} - return file, nil +// tunBatchReader implements BatchReader for efficient batch packet reading +type tunBatchReader struct { + fd int + device string +} + +func (r *tunBatchReader) Read(b []byte) (int, error) { + // Use poll to wait for data, then read + for { + n, err := unix.Read(r.fd, b) + if err == nil { + return n, nil + } + if err == unix.EAGAIN || err == unix.EWOULDBLOCK { + // Wait for data + pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}} + _, err = unix.Poll(pfds, -1) + if err != nil { + if err == unix.EINTR { + continue + } + return 0, err + } + continue + } + return n, err + } +} + +func (r *tunBatchReader) Write(b []byte) (int, error) { + return unix.Write(r.fd, b) +} + +func (r *tunBatchReader) Close() error { + return unix.Close(r.fd) +} + +// ReadBatch reads up to len(packets) packets from the TUN device. +// It drains all available packets without blocking, using poll() only +// when no packets have been read yet. 
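+// If an error or EOF is encountered after at least one packet has already been
+// read, the packets collected so far are returned and that error is not
+// reported.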
+func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) { + count := 0 + maxPackets := len(packets) + if len(sizes) < maxPackets { + maxPackets = len(sizes) + } + + for count < maxPackets { + n, err := unix.Read(r.fd, packets[count]) + if err == nil && n > 0 { + sizes[count] = n + count++ + continue + } + + if err == unix.EAGAIN || err == unix.EWOULDBLOCK { + // No more packets available + if count > 0 { + // We have some packets, return them + return count, nil + } + // No packets yet, wait for at least one + pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}} + _, err = unix.Poll(pfds, -1) + if err != nil { + if err == unix.EINTR { + continue + } + return 0, err + } + continue + } + + if err != nil { + if count > 0 { + // Return what we have + return count, nil + } + return 0, err + } + + if n == 0 { + if count > 0 { + return count, nil + } + return 0, io.EOF + } + } + + return count, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -284,6 +375,111 @@ func (t *tun) Write(b []byte) (int, error) { } } +// EnableBatchReading sets the TUN fd to non-blocking mode to enable batch reading. +// This should be called before using ReadBatch. +func (t *tun) EnableBatchReading() error { + if t.nonBlocking { + return nil + } + err := unix.SetNonblock(t.fd, true) + if err != nil { + return err + } + t.nonBlocking = true + return nil +} + +// Read overrides the default Read to handle non-blocking mode +func (t *tun) Read(b []byte) (int, error) { + if !t.nonBlocking { + // Use the embedded ReadWriteCloser for blocking reads + return t.ReadWriteCloser.Read(b) + } + + // Non-blocking read with poll + for { + n, err := unix.Read(t.fd, b) + if err == nil { + return n, nil + } + if err == unix.EAGAIN || err == unix.EWOULDBLOCK { + // Wait for data + pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}} + _, err = unix.Poll(pfds, -1) + if err != nil { + if err == unix.EINTR { + continue + } + return 0, err + } + continue + } + return n, err + } +} + +// ReadBatch reads up to len(packets) packets from the TUN device. +// EnableBatchReading must be called first. +func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) { + if !t.nonBlocking { + // Fallback to single read if non-blocking not enabled + n, err := t.ReadWriteCloser.Read(packets[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil + } + + count := 0 + maxPackets := len(packets) + if len(sizes) < maxPackets { + maxPackets = len(sizes) + } + + for count < maxPackets { + n, err := unix.Read(t.fd, packets[count]) + if err == nil && n > 0 { + sizes[count] = n + count++ + continue + } + + if err == unix.EAGAIN || err == unix.EWOULDBLOCK { + // No more packets available + if count > 0 { + return count, nil + } + // No packets yet, wait for at least one + pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}} + _, err = unix.Poll(pfds, -1) + if err != nil { + if err == unix.EINTR { + continue + } + return 0, err + } + continue + } + + if err != nil { + if count > 0 { + return count, nil + } + return 0, err + } + + if n == 0 { + if count > 0 { + return count, nil + } + return 0, io.EOF + } + } + + return count, nil +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c)