From c9a695c2bf4dcb7f8f301a3a8ec8195d411496fd Mon Sep 17 00:00:00 2001 From: Ryan Date: Thu, 6 Nov 2025 10:56:53 -0500 Subject: [PATCH] try with sendmmsg merged back --- inside.go | 106 ++++++++++++++++++++------------------ interface.go | 83 +++++++++++++++++++++++++++++- udp/conn.go | 9 ++++ udp/udp_darwin.go | 11 ++++ udp/udp_generic.go | 11 ++++ udp/udp_linux.go | 112 +++++++++++++++++++++++++++++++++++++++++ udp/udp_linux_32.go | 10 ++++ udp/udp_linux_64.go | 10 ++++ udp/udp_rio_windows.go | 11 ++++ udp/udp_tester.go | 11 ++++ 10 files changed, 324 insertions(+), 50 deletions(-) diff --git a/inside.go b/inside.go index d24ed31..bba769f 100644 --- a/inside.go +++ b/inside.go @@ -11,19 +11,19 @@ import ( "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, queue func(netip.AddrPort, int), q int, localCache firewall.ConntrackCache) bool { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) } - return + return false } // Ignore local broadcast packets if f.dropLocalBroadcast { if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { - return + return false } } @@ -40,12 +40,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Otherwise, drop. On linux, we should never see these packets - Linux // routes packets from the nebula addr to the nebula addr through the loopback device. - return + return false } // Ignore multicast packets if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { - return + return false } hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { @@ -59,26 +59,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet WithField("fwPacket", fwPacket). Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") } - return + return false } if !ready { - return + return false } dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - - } else { - f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l). - WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping outbound packet") - } + return f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, queue, q) } + + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). + WithField("fwPacket", fwPacket). + WithField("reason", dropReason). + Debugln("dropping outbound packet") + } + return false } func (f *Interface) rejectInside(packet []byte, out []byte, q int) { @@ -117,7 +117,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) + _ = f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, nil, q) } // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established @@ -228,7 +228,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) + _ = f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0) } // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr @@ -258,12 +258,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) + _ = f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0) } func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) + _ = f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, nil, 0) } // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done @@ -331,9 +331,12 @@ func (f *Interface) SendVia(via *HostInfo, f.connectionManager.RelayUsed(relay.LocalIndex) } -func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { +// sendNoMetrics encrypts and writes/queues an outbound packet. It returns true +// when the payload has been handed to a caller-provided queue (meaning the +// caller is responsible for flushing it later). +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, queue func(netip.AddrPort, int), q int) bool { if ci.eKey == nil { - return + return false } useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() fullOut := out @@ -380,32 +383,39 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType WithField("udpAddr", remote).WithField("counter", c). WithField("attemptedCounter", c). Error("Failed to encrypt outgoing packet") - return + return false } - if remote.IsValid() { - err = f.writers[q].WriteTo(out, remote) - if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") - } - } else if hostinfo.remote.IsValid() { - err = f.writers[q].WriteTo(out, hostinfo.remote) - if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") - } - } else { - // Try to send via a relay - for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) - if err != nil { - hostinfo.relayState.DeleteRelay(relayIP) - hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") - continue - } - f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) - break - } + dest := remote + if !dest.IsValid() { + dest = hostinfo.remote } + + if dest.IsValid() { + if queue != nil { + queue(dest, len(out)) + return true + } + + err = f.writers[q].WriteTo(out, dest) + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("udpAddr", dest).Error("Failed to write outgoing packet") + } + return false + } + + // Try to send via a relay + for _, relayIP := range hostinfo.relayState.CopyRelayIps() { + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) + if err != nil { + hostinfo.relayState.DeleteRelay(relayIP) + hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + continue + } + f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) + break + } + + return false } diff --git a/interface.go b/interface.go index 9c75ed2..43d958f 100644 --- a/interface.go +++ b/interface.go @@ -29,6 +29,7 @@ const ( outboundBatchSizeDefault = 32 batchFlushIntervalDefault = 50 * time.Microsecond maxOutstandingBatchesDefault = 1028 + sendBatchSizeDefault = 32 ) type InterfaceConfig struct { @@ -120,12 +121,21 @@ type Interface struct { packetBatchPool sync.Pool outboundBatchPool sync.Pool + sendPool sync.Pool + sendBatchSize int + inboundBatchSize int outboundBatchSize int batchFlushInterval time.Duration maxOutstandingPerChan int } +type outboundSend struct { + buf *[]byte + length int + addr netip.AddrPort +} + type packetBatch struct { packets []*packet.Packet } @@ -194,6 +204,48 @@ func (f *Interface) releaseOutboundBatch(b *outboundBatch) { f.outboundBatchPool.Put(b) } +func (f *Interface) getSendBuffer() *[]byte { + if v := f.sendPool.Get(); v != nil { + buf := v.(*[]byte) + *buf = (*buf)[:0] + return buf + } + b := make([]byte, mtu) + return &b +} + +func (f *Interface) releaseSendBuffer(buf *[]byte) { + if buf == nil { + return + } + *buf = (*buf)[:0] + f.sendPool.Put(buf) +} + +func (f *Interface) flushSendQueue(q int, pending *[]outboundSend) { + if len(*pending) == 0 { + return + } + + batch := make([]udp.BatchPacket, len(*pending)) + for i, entry := range *pending { + batch[i] = udp.BatchPacket{ + Payload: (*entry.buf)[:entry.length], + Addr: entry.addr, + } + } + + sent, err := f.writers[q].WriteBatch(batch) + if err != nil { + f.l.WithError(err).WithField("sent", sent).Error("Failed to batch send packets") + } + + for _, entry := range *pending { + f.releaseSendBuffer(entry.buf) + } + *pending = (*pending)[:0] +} + type EncWriter interface { SendVia(via *HostInfo, relay *Relay, @@ -316,6 +368,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { outboundBatchSize: bc.OutboundBatchSize, batchFlushInterval: bc.FlushInterval, maxOutstandingPerChan: bc.MaxOutstandingPerChan, + sendBatchSize: bc.OutboundBatchSize, } for i := 0; i < c.routines; i++ { @@ -340,6 +393,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return newOutboundBatch(ifce.outboundBatchSize) }} + ifce.sendPool = sync.Pool{New: func() any { + buf := make([]byte, mtu) + return &buf + }} + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) @@ -539,18 +597,39 @@ func (f *Interface) workerOut(i int, ctx context.Context) { conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) fwPacket1 := &firewall.Packet{} nb1 := make([]byte, 12, 12) - result1 := make([]byte, mtu) + pending := make([]outboundSend, 0, f.sendBatchSize) for { select { case batch := <-f.outbound[i]: for _, data := range batch.payloads { - f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) + sendBuf := f.getSendBuffer() + buf := (*sendBuf)[:0] + queue := func(addr netip.AddrPort, length int) { + pending = append(pending, outboundSend{ + buf: sendBuf, + length: length, + addr: addr, + }) + if len(pending) >= f.sendBatchSize { + f.flushSendQueue(i, &pending) + } + } + sent := f.consumeInsidePacket(*data, fwPacket1, nb1, buf, queue, i, conntrackCache.Get(f.l)) + if !sent { + f.releaseSendBuffer(sendBuf) + } *data = (*data)[:mtu] f.outPool.Put(data) } f.releaseOutboundBatch(batch) + if len(pending) > 0 { + f.flushSendQueue(i, &pending) + } case <-ctx.Done(): + if len(pending) > 0 { + f.flushSendQueue(i, &pending) + } f.wg.Done() return } diff --git a/udp/conn.go b/udp/conn.go index 340d30c..2a5d7e4 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -18,10 +18,16 @@ type Conn interface { LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader) error WriteTo(b []byte, addr netip.AddrPort) error + WriteBatch(pkts []BatchPacket) (int, error) ReloadConfig(c *config.C) Close() error } +type BatchPacket struct { + Payload []byte + Addr netip.AddrPort +} + type NoopConn struct{} func (NoopConn) Rebind() error { @@ -36,6 +42,9 @@ func (NoopConn) ListenOut(_ EncReader) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) { + return 0, nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index c03b2d1..5d37181 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) { + sent := 0 + for _, pkt := range pkts { + if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil { + return sent, err + } + sent++ + } + return sent, nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 6538520..f877b0b 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -42,6 +42,17 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { return err } +func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) { + sent := 0 + for _, pkt := range pkts { + if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil { + return sent, err + } + sent++ + } + return sent, nil +} + func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() diff --git a/udp/udp_linux.go b/udp/udp_linux.go index dc39cc6..cfb8470 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -343,6 +343,118 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } +func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) { + if len(pkts) == 0 { + return 0, nil + } + + msgs := make([]rawMessage, 0, len(pkts)) + iovs := make([]iovec, 0, len(pkts)) + names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts)) + + sent := 0 + + for _, pkt := range pkts { + if len(pkt.Payload) == 0 { + sent++ + continue + } + + if u.enableGSO && pkt.Addr.IsValid() { + if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil { + sent++ + continue + } else if !errors.Is(err, errGSOFallback) { + return sent, err + } + } + + if !pkt.Addr.IsValid() { + if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil { + return sent, err + } + sent++ + continue + } + + msgs = append(msgs, rawMessage{}) + iovs = append(iovs, iovec{}) + names = append(names, [unix.SizeofSockaddrInet6]byte{}) + + idx := len(msgs) - 1 + msg := &msgs[idx] + iov := &iovs[idx] + name := &names[idx] + + setIovecSlice(iov, pkt.Payload) + msg.Hdr.Iov = iov + msg.Hdr.Iovlen = 1 + setRawMessageControl(msg, nil) + msg.Hdr.Flags = 0 + + nameLen, err := u.encodeSockaddr(name[:], pkt.Addr) + if err != nil { + return sent, err + } + msg.Hdr.Name = &name[0] + msg.Hdr.Namelen = nameLen + } + + if len(msgs) == 0 { + return sent, nil + } + + offset := 0 + for offset < len(msgs) { + n, _, errno := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[offset])), + uintptr(len(msgs)-offset), + 0, + 0, + 0, + ) + + if errno != 0 { + if errno == unix.EINTR { + continue + } + return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno} + } + + if n == 0 { + break + } + offset += int(n) + } + + return sent + len(msgs), nil +} + +func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) { + if u.isV4 { + if !addr.Addr().Is4() { + return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + } + var sa unix.RawSockaddrInet4 + sa.Family = unix.AF_INET + sa.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet4 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil + } + + var sa unix.RawSockaddrInet6 + sa.Family = unix.AF_INET6 + sa.Addr = addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet6 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil +} + func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index 89e8002..1081117 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -77,3 +77,13 @@ func getRawMessageFlags(msg *rawMessage) int { func setCmsgLen(h *unix.Cmsghdr, l int) { h.Len = uint32(l) } + +func setIovecSlice(iov *iovec, b []byte) { + if len(b) == 0 { + iov.Base = nil + iov.Len = 0 + return + } + iov.Base = &b[0] + iov.Len = uint32(len(b)) +} diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 83e1b34..0a7abef 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -80,3 +80,13 @@ func getRawMessageFlags(msg *rawMessage) int { func setCmsgLen(h *unix.Cmsghdr, l int) { h.Len = uint64(l) } + +func setIovecSlice(iov *iovec, b []byte) { + if len(b) == 0 { + iov.Base = nil + iov.Len = 0 + return + } + iov.Base = &b[0] + iov.Len = uint64(len(b)) +} diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index ba96a20..4825b4f 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -304,6 +304,17 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } +func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) { + sent := 0 + for _, pkt := range pkts { + if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil { + return sent, err + } + sent++ + } + return sent, nil +} + func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 8d5e6c1..edf76dd 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -106,6 +106,17 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } +func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) { + sent := 0 + for _, pkt := range pkts { + if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil { + return sent, err + } + sent++ + } + return sent, nil +} + func (u *TesterConn) ListenOut(r EncReader) { for { p, ok := <-u.RxPackets