From b2bc6a09ca9cb30f43b0166b08bb6ec4bcb94570 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 11 Nov 2025 15:06:45 -0500 Subject: [PATCH] write in batches --- inside.go | 139 ++++++++++++++++++++++++++++++++++++++++++++ interface.go | 19 +++--- udp/conn.go | 4 ++ udp/udp_darwin.go | 11 ++++ udp/udp_linux.go | 85 +++++++++++++++++++++++++++ udp/udp_linux_64.go | 38 ++++++++++++ 6 files changed, 286 insertions(+), 10 deletions(-) diff --git a/inside.go b/inside.go index d24ed31..d675f78 100644 --- a/inside.go +++ b/inside.go @@ -11,6 +11,145 @@ import ( "github.com/slackhq/nebula/routing" ) +// consumeInsidePackets processes multiple packets in a batch for improved performance +// packets: slice of packet buffers to process +// sizes: slice of packet sizes +// count: number of packets to process +// outs: slice of output buffers (one per packet) with virtio headroom +// q: queue index +// localCache: firewall conntrack cache +func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache) { + // Reusable per-packet state + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + // Accumulate encrypted packets for batch sending + batchPackets := make([][]byte, 0, count) + batchAddrs := make([]netip.AddrPort, 0, count) + + // Process each packet in the batch + for i := 0; i < count; i++ { + packet := packets[i][:sizes[i]] + out := outs[i] + + // Inline the consumeInsidePacket logic for better performance + err := newPacket(packet, false, fwPacket) + if err != nil { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + } + continue + } + + // Ignore local broadcast packets + if f.dropLocalBroadcast { + if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { + continue + } + } + + if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { + // Immediately forward packets from self to self. + if immediatelyForwardToSelf { + _, err := f.readers[q].Write(packet) + if err != nil { + f.l.WithError(err).Error("Failed to forward to tun") + } + } + continue + } + + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { + continue + } + + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + + if hostinfo == nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). + WithField("fwPacket", fwPacket). + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + } + continue + } + + if !ready { + continue + } + + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + if dropReason != nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). + WithField("fwPacket", fwPacket). + WithField("reason", dropReason). + Debugln("dropping outbound packet") + } + continue + } + + // Encrypt and prepare packet for batch sending + ci := hostinfo.ConnectionState + if ci.eKey == nil { + continue + } + + // Check if this needs relay - if so, send immediately and skip batching + useRelay := !hostinfo.remote.IsValid() + if useRelay { + // Handle relay sends individually (less common path) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q) + continue + } + + // Encrypt the packet for batch sending + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + // Query lighthouse if needed + if hostinfo.lastRebindCount != f.rebindCount { + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("counter", c). + Error("Failed to encrypt outgoing packet") + continue + } + + // Add to batch + batchPackets = append(batchPackets, out) + batchAddrs = append(batchAddrs, hostinfo.remote) + } + + // Send all accumulated packets in one batch + if len(batchPackets) > 0 { + n, err := f.writers[q].WriteMulti(batchPackets, batchAddrs) + if err != nil { + f.l.WithError(err).WithField("sent", n).WithField("total", len(batchPackets)).Error("Failed to send batch") + } + } +} + func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { diff --git a/interface.go b/interface.go index df4c3d3..8180cbf 100644 --- a/interface.go +++ b/interface.go @@ -333,12 +333,13 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe } sizes := make([]int, batchSize) - // Per-packet state (reused across batches) - // Allocate out buffer with virtio header headroom to avoid copies on write - outBuf := make([]byte, virtioNetHdrLen+mtu) - out := outBuf[virtioNetHdrLen:] - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + // Allocate output buffers for batch processing (one per packet) + // Each has virtio header headroom to avoid copies on write + outs := make([][]byte, batchSize) + for idx := range outs { + outBuf := make([]byte, virtioNetHdrLen+mtu) + outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom + } conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) @@ -354,10 +355,8 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe os.Exit(2) } - // Process each packet in the batch - for j := 0; j < n; j++ { - f.consumeInsidePacket(bufs[j][:sizes[j]], fwPacket, nb, out, i, conntrackCache.Get(f.l)) - } + // Process all packets in the batch at once + f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l)) } } diff --git a/udp/conn.go b/udp/conn.go index 895b0df..8c821d3 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -18,6 +18,7 @@ type Conn interface { LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error + WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) ReloadConfig(c *config.C) Close() error } @@ -36,6 +37,9 @@ func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) { + return 0, nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index c0c6233..b409d77 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +// WriteMulti sends multiple packets - fallback implementation without sendmmsg +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ec0bf64..aec5215 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -194,6 +194,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + if len(packets) != len(addrs) { + return 0, fmt.Errorf("packets and addrs length mismatch") + } + if len(packets) == 0 { + return 0, nil + } + if u.isV4 { + return u.writeMulti4(packets, addrs) + } + return u.writeMulti6(packets, addrs) +} + func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 @@ -248,6 +261,78 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } } +func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) { + msgs, iovecs, names := u.PrepareWriteMessages4(len(packets)) + + for i := range packets { + if !addrs[i].Addr().Is4() { + return i, ErrInvalidIPv6RemoteForSocket + } + + // Setup the packet buffer + iovecs[i].Base = &packets[i][0] + iovecs[i].Len = uint64(len(packets[i])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET + rsa.Addr = addrs[i].Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) + } + + for { + n, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + 0, + 0, + 0, + ) + + if err != 0 { + return int(n), &net.OpError{Op: "sendmmsg", Err: err} + } + + return int(n), nil + } +} + +func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) { + msgs, iovecs, names := u.PrepareWriteMessages6(len(packets)) + + for i := range packets { + // Setup the packet buffer + iovecs[i].Base = &packets[i][0] + iovecs[i].Len = uint64(len(packets[i])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET6 + rsa.Addr = addrs[i].Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) + } + + for { + n, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + 0, + 0, + 0, + ) + + if err != 0 { + return int(n), &net.OpError{Op: "sendmmsg", Err: err} + } + + return int(n), nil + } +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a97..36ce8a4 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -55,3 +55,41 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +func (u *StdConn) PrepareWriteMessages4(n int) ([]rawMessage, []iovec, [][]byte) { + msgs := make([]rawMessage, n) + iovecs := make([]iovec, n) + names := make([][]byte, n) + + for i := range msgs { + names[i] = make([]byte, unix.SizeofSockaddrInet4) + + // Point to the iovec in the slice + msgs[i].Hdr.Iov = &iovecs[i] + msgs[i].Hdr.Iovlen = 1 + + msgs[i].Hdr.Name = &names[i][0] + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 + } + + return msgs, iovecs, names +} + +func (u *StdConn) PrepareWriteMessages6(n int) ([]rawMessage, []iovec, [][]byte) { + msgs := make([]rawMessage, n) + iovecs := make([]iovec, n) + names := make([][]byte, n) + + for i := range msgs { + names[i] = make([]byte, unix.SizeofSockaddrInet6) + + // Point to the iovec in the slice + msgs[i].Hdr.Iov = &iovecs[i] + msgs[i].Hdr.Iovlen = 1 + + msgs[i].Hdr.Name = &names[i][0] + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 + } + + return msgs, iovecs, names +}