diff --git a/handshake_manager.go b/handshake_manager.go index 87257028..1384b346 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -974,6 +974,7 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn nb := make([]byte, 12, 12) out := make([]byte, mtu) for _, cp := range hh.packetStore { + //todo use a sendbatcher cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) } f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) diff --git a/inside.go b/inside.go index 68cb38ec..de23713e 100644 --- a/inside.go +++ b/inside.go @@ -2,6 +2,7 @@ package nebula import ( "context" + "io" "log/slog" "net/netip" @@ -9,10 +10,16 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/overlay/batch" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) { + // borrowed: pkt.Bytes is owned by the originating tio.Queue and is + // only valid until the next Read on that queue. If you must keep + // the packet, use pkt.Clone() to detach it + packet := pkt.Bytes err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { @@ -37,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { - _, err := f.readers[q].Write(packet) + err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + _, werr := f.readers[q].Write(seg) + return werr + }) if err != nil { f.l.Error("Failed to forward to tun", "error", err) } @@ -53,11 +63,23 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + // borrowed: SegmentSuperpacket builds each segment in the kernel-supplied pkt + // bytes underneath. cachePacket explicitly copies its argument (handshake_manager.go cachePacket), + // so retaining segments past the loop is safe. 
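+			// (If this path ever needed to keep the whole superpacket rather
+			// than the copied segments, pkt.Clone() would be the way to detach
+			// it, per the borrow note at the top of consumeInsidePacket.)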
+ err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + hh.cachePacket(f.l, header.Message, 0, seg, f.sendMessageNow, f.cachedPacketMetrics) + return nil + }) + if err != nil && f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Failed to segment superpacket for handshake cache", + "error", err, + "vpnAddr", fwPacket.RemoteAddr, + ) + } }) if hostinfo == nil { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", "vpnAddr", fwPacket.RemoteAddr, @@ -73,10 +95,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - + f.sendInsideMessage(hostinfo, pkt, nb, sendBatch) } else { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(f.l).Debug("dropping outbound packet", "fwPacket", fwPacket, @@ -86,6 +107,124 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } +func (f *Interface) sendInsideEncrypt(hostinfo *HostInfo, ci *ConnectionState, seg, scratch, nb []byte) []byte { + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + out, encErr := ci.eKey.EncryptDanger(out, out, seg, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if encErr != nil { + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", encErr, + "udpAddr", hostinfo.remote, + "counter", c, + ) + // Skip this segment; the rest of the superpacket can still + // go out — TCP will retransmit anything we drop here. + return nil + } + + return out +} + +// sendInsideMessage encrypts a firewall-approved inside packet (or every +// segment of a TSO/USO superpacket) into the caller's batch slot for +// later sendmmsg flush. Segmentation is fused with encryption here so the +// kernel-supplied superpacket bytes never get written into a separate +// scratch arena: SegmentSuperpacket builds each segment's plaintext in +// segScratch[:segLen] in turn, and we encrypt directly into a fresh +// SendBatch slot. +func (f *Interface) sendInsideMessage(hostinfo *HostInfo, pkt tio.Packet, nb []byte, sendBatch batch.TxBatcher) { + ci := hostinfo.ConnectionState + if ci.eKey == nil { + return + } + + if hostinfo.lastRebindCount != f.rebindCount { + //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is + // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. 
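+		// (The 256 figure comes from the rebind counter being a single
+		// byte; it wraps after 256 increments.)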
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
+		hostinfo.lastRebindCount = f.rebindCount
+		if f.l.Enabled(context.Background(), slog.LevelDebug) {
+			hostinfo.logger(f.l).Debug("Lighthouse update triggered for punch due to rebind counter",
+				"vpnAddrs", hostinfo.vpnAddrs,
+			)
+		}
+	}
+
+	if !hostinfo.remote.IsValid() { //the relay path
+		//first, find our relay hostinfo:
+		var relayHostInfo *HostInfo
+		var relay *Relay
+		var err error
+		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
+			relayHostInfo, relay, err = f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
+			if err != nil {
+				hostinfo.relayState.DeleteRelay(relayIP)
+				hostinfo.logger(f.l).Info("sendInsideMessage failed to find relay HostInfo",
+					"relay", relayIP,
+					"error", err,
+				)
+				continue
+			}
+			break
+		}
+		if relayHostInfo == nil || relay == nil {
+			//failure already logged
+			return
+		}
+
+		err = tio.SegmentSuperpacket(pkt, func(seg []byte) error {
+			//relay header + header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305) + relay tag
+			scratch := sendBatch.Reserve(header.Len + header.Len + len(seg) + 16 + 16)
+
+			innerPacket := f.sendInsideEncrypt(hostinfo, ci, seg, scratch[header.Len:], nb)
+			if innerPacket == nil {
+				return nil
+			}
+
+			//now we need to do a relay-encrypt:
+			toSend, err := f.prepareSendVia(relayHostInfo, relay, innerPacket, nb, scratch, true)
+			if err != nil {
+				//already logged
+				return nil
+			}
+
+			sendBatch.Commit(toSend, relayHostInfo.remote, 0)
+			return nil
+		})
+		if err != nil {
+			hostinfo.logger(f.l).Error("Failed to segment superpacket for relay send", "error", err)
+		}
+		return
+	}
+
+	err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
+		// header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305)
+		scratch := sendBatch.Reserve(header.Len + len(seg) + 16)
+
+		out := f.sendInsideEncrypt(hostinfo, ci, seg, scratch, nb)
+		if out == nil {
+			return nil
+		}
+
+		sendBatch.Commit(out, hostinfo.remote, 0)
+		return nil
+	})
+	if err != nil {
+		hostinfo.logger(f.l).Error("Failed to segment superpacket for send",
+			"error", err,
+		)
+	}
+}
+
 func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
 	if !f.firewall.InSendReject {
 		return
@@ -275,21 +414,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
 
-// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
-// to the payload for the ultimate target host, making this a useful method for sending
-// handshake messages to peers through relay tunnels.
-// via is the HostInfo through which the message is relayed.
-// ad is the plaintext data to authenticate, but not encrypt
-// nb is a buffer used to store the nonce value, re-used for performance reasons.
-// out is a buffer used to store the result of the Encrypt operation
-// q indicates which writer to use to send the packet. 
-func (f *Interface) SendVia(via *HostInfo, +func (f *Interface) prepareSendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool, -) { +) ([]byte, error) { if noiseutil.EncryptLockNeeded { // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check via.ConnectionState.writeLock.Lock() @@ -311,7 +442,7 @@ func (f *Interface) SendVia(via *HostInfo, "headerLen", len(out), "cipherOverhead", via.ConnectionState.eKey.Overhead(), ) - return + return nil, io.ErrShortBuffer } // The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload. @@ -331,13 +462,36 @@ func (f *Interface) SendVia(via *HostInfo, } if err != nil { via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) + return nil, err + } + f.connectionManager.RelayUsed(relay.LocalIndex) + return out, nil +} + +// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done +// to the payload for the ultimate target host, making this a useful method for sending +// handshake messages to peers through relay tunnels. +// via is the HostInfo through which the message is relayed. +// ad is the plaintext data to authenticate, but not encrypt +// nb is a buffer used to store the nonce value, re-used for performance reasons. +// out is a buffer used to store the result of the Encrypt operation +// q indicates which writer to use to send the packet. +func (f *Interface) SendVia(via *HostInfo, + relay *Relay, + ad, + nb, + out []byte, + nocopy bool, +) { + toSend, err := f.prepareSendVia(via, relay, ad, nb, out, nocopy) + if err != nil { + via.logger(f.l).Info("Failed to prepareSendVia", "error", err) return } - err = f.writers[0].WriteTo(out, via.remote) + err = f.writers[0].WriteTo(toSend, via.remote) if err != nil { via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) } - f.connectionManager.RelayUsed(relay.LocalIndex) } func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { diff --git a/interface.go b/interface.go index f4f87415..209019eb 100644 --- a/interface.go +++ b/interface.go @@ -12,13 +12,14 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/wire" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/overlay/batch" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/udp" ) @@ -90,7 +91,11 @@ type Interface struct { ctx context.Context writers []udp.Conn readers []tio.Queue - wg sync.WaitGroup + // batchers is one per tun queue, wrapping readers[i]. + // decryptToTun sends plaintext into the batch.RxBatcher; + // listenOut calls its Flush at the end of each UDP recvmmsg batch. + batchers []batch.RxBatcher + wg sync.WaitGroup // fatalErr holds the first unexpected reader error that caused shutdown. 
// nil means "no fatal error" (yet)
@@ -189,6 +194,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		version:            c.version,
 		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]tio.Queue, c.routines),
+		batchers:           make([]batch.RxBatcher, c.routines),
 		myVpnNetworks:      cs.myVpnNetworks,
 		myVpnNetworksTable: cs.myVpnNetworksTable,
 		myVpnAddrs:         cs.myVpnAddrs,
@@ -254,6 +260,10 @@ func (f *Interface) activate() error {
 		}
 	}
 	f.readers = f.inside.Readers()
+	for i := range f.readers {
+		arena := batch.NewArena(batch.DefaultPassthroughArenaCap)
+		f.batchers[i] = batch.NewPassthrough(f.readers[i], arena)
+	}
 
 	f.wg.Add(1) // for us to wait on Close() to return
 	if err = f.inside.Activate(); err != nil {
@@ -310,14 +320,22 @@ func (f *Interface) listenOut(i int) {
 	ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
-	plaintext := make([]byte, udp.MTU)
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
-	err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
-		f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
-	})
+	listener := func(fromUdpAddr netip.AddrPort, payload []byte, meta udp.RxMeta) {
+		plaintext := f.batchers[i].Reserve(len(payload))
+		f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(), meta)
+	}
+
+	flusher := func() {
+		if err := f.batchers[i].Flush(); err != nil {
+			f.l.Error("Failed to flush tun coalescer", "error", err)
+		}
+	}
+
+	err := li.ListenOut(listener, flusher)
 	if err != nil && !f.closed.Load() {
 		f.l.Error("Error while reading inbound packet, closing", "error", err)
@@ -332,6 +350,9 @@ func (f *Interface) listenIn(reader tio.Queue, q int) {
 	// TODO get the amount of bonus info from the reader
 	packets := make([]wire.TunPacket, 1)
 	out := make([]byte, mtu)
+	rejectBuf := make([]byte, mtu)
+	arenaSize := batch.DefaultSendBatchArenaCap
+	sb := batch.NewSendBatch(f.writers[q], batch.SendBatchCap, batch.NewArena(arenaSize))
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
@@ -346,9 +367,13 @@ func (f *Interface) listenIn(reader tio.Queue, q int) {
 			}
 			break
 		}
+		ctCache := conntrackCache.Get()
 
-		for i := range n {
-			f.consumeInsidePacket(packets[i].Bytes, fwPacket, nb, out, q, ctCache)
+		for i := range n {
+			f.consumeInsidePacket(packets[i], fwPacket, nb, sb, rejectBuf, q, ctCache)
+		}
+		if err := sb.Flush(); err != nil {
+			f.l.Error("Failed to write outgoing batch", "error", err, "writer", q)
 		}
 	}
 }
diff --git a/outside.go b/outside.go
index 17013ed3..cf079fd7 100644
--- a/outside.go
+++ b/outside.go
@@ -13,6 +13,7 @@ import (
 
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 )
 
@@ -22,7 +23,7 @@ const (
 
 var ErrOutOfWindow = errors.New("out of window packet")
 
-func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
 	err := h.Parse(packet)
 	if err != nil {
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -110,8 +111,7 @@ func (f *Interface) 
readOutsidePackets(via ViaSender, out []byte, packet []byte, // Relay packets are special if isMessageRelay { - f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) - + f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache, meta) return } @@ -135,7 +135,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, case header.Message: switch h.Subtype { case header.MessageNone: - f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache, meta) default: hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) return @@ -168,7 +168,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } } -func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { // The entire body is sent as AD, not encrypted. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's @@ -211,7 +211,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, relay: relay, IsRelayed: true, } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, meta) case ForwardingType: // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) @@ -229,7 +229,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, switch targetRelay.Type { case ForwardingType: // Forward this packet through the relay tunnel - // Find the target HostInfo + // Find the target HostInfo //todo it would potentially be nice to batch these f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) case TerminalType: hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") @@ -512,7 +512,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { err := newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", @@ -536,7 +536,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p return } - _, err = f.readers[q].Write(out) + err = f.batchers[q].Commit(out) if err != nil { f.l.Error("Failed to write to tun", "error", err) } diff --git a/overlay/batch/batch.go b/overlay/batch/batch.go new file mode 100644 index 
00000000..d171d136
--- /dev/null
+++ b/overlay/batch/batch.go
@@ -0,0 +1,28 @@
+package batch
+
+import "net/netip"
+
+type RxBatcher interface {
+	// Reserve hands out a borrowed buffer of sz bytes for one packet.
+	Reserve(sz int) []byte
+	// Commit borrows pkt. The caller must keep pkt valid until the next Flush.
+	Commit(pkt []byte) error
+	// Flush emits every queued packet in arrival order. Returns the
+	// first error observed; keeps draining so one bad packet doesn't hold up
+	// the rest. After Flush returns, borrowed payload slices may be recycled.
+	Flush() error
+}
+
+type TxBatcher interface {
+	// Reserve hands out a borrowed buffer of sz bytes for one packet.
+	Reserve(sz int) []byte
+	// Commit borrows pkt and records its destination plus the 2-bit
+	// IP-level ECN codepoint to set on the outer (carrier) header. The
+	// caller must keep pkt valid until the next Flush. Pass 0 (Not-ECT)
+	// to leave the outer ECN field unset.
+	Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
+	// Flush emits every queued packet via the underlying batch writer in
+	// arrival order and returns any error that writer reports. After Flush
+	// returns, borrowed payload slices may be recycled.
+	Flush() error
+}
diff --git a/overlay/batch/coalesce_core.go b/overlay/batch/coalesce_core.go
new file mode 100644
index 00000000..0213ea9d
--- /dev/null
+++ b/overlay/batch/coalesce_core.go
@@ -0,0 +1,42 @@
+package batch
+
+// Arena is an injectable byte-slab that hands out non-overlapping borrowed
+// slices via Reserve and releases them in bulk via Reset. Coalescers take
+// an *Arena at construction so the caller controls the slab lifetime and
+// can share one slab across multiple coalescers (MultiCoalescer hands the
+// same *Arena to every lane so the lanes don't carry their own backings).
+//
+// Reserve borrows; the slice is valid until the next Reset. The slab grows
+// (by allocating a fresh, larger backing array) if a Reserve doesn't fit;
+// pre-size the arena via NewArena to avoid that path on the hot path.
+type Arena struct {
+	buf []byte
+}
+
+// NewArena returns an Arena with a pre-allocated backing of the given
+// capacity. Pass 0 if you don't intend to call Reserve (e.g. a test that
+// only feeds the coalescer pre-made []byte packets via Commit).
+func NewArena(capacity int) *Arena {
+	return &Arena{buf: make([]byte, 0, capacity)}
+}
+
+// Reserve hands out a non-overlapping sz-byte slice from the arena. If the
+// request doesn't fit the current backing, a fresh, larger backing is
+// allocated; already-borrowed slices reference the old backing and remain
+// valid until Reset.
+func (a *Arena) Reserve(sz int) []byte {
+	if len(a.buf)+sz > cap(a.buf) {
+		newCap := max(cap(a.buf)*2, sz)
+		a.buf = make([]byte, 0, newCap)
+	}
+	start := len(a.buf)
+	a.buf = a.buf[:start+sz]
+	return a.buf[start : start+sz : start+sz]
+}
+
+// Reset releases every slice handed out since the last Reset. Callers must
+// not use any previously-borrowed slice after this returns. The underlying
+// backing array is retained so subsequent Reserves don't re-allocate.
+func (a *Arena) Reset() {
+	a.buf = a.buf[:0]
+}
diff --git a/overlay/batch/passthrough.go b/overlay/batch/passthrough.go
new file mode 100644
index 00000000..aea90fa4
--- /dev/null
+++ b/overlay/batch/passthrough.go
@@ -0,0 +1,52 @@
+package batch
+
+import (
+	"io"
+
+	"github.com/slackhq/nebula/udp"
+)
+
+// Passthrough is an RxBatcher that performs no coalescing: it accumulates committed packets and writes each one out individually on Flush.
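+//
+// A minimal wiring sketch (tunWriter stands in for whatever io.Writer the
+// caller owns; the names are illustrative, not part of the API):
+//
+//	p := NewPassthrough(tunWriter, NewArena(DefaultPassthroughArenaCap))
+//	buf := p.Reserve(len(plain))
+//	copy(buf, plain)
+//	_ = p.Commit(buf) // borrowed until the next Flush
+//	_ = p.Flush()     // writes everything out and recycles the arena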
+type Passthrough struct { + out io.Writer + slots [][]byte + arena *Arena + cursor int +} + +const passthroughBaseNumSlots = 128 + +// DefaultPassthroughArenaCap is the recommended arena capacity for a +// standalone Passthrough batcher: 128 slots × udp.MTU ≈ 1.1 MiB. +const DefaultPassthroughArenaCap = passthroughBaseNumSlots * udp.MTU + +func NewPassthrough(w io.Writer, arena *Arena) *Passthrough { + return &Passthrough{ + out: w, + slots: make([][]byte, 0, passthroughBaseNumSlots), + arena: arena, + } +} + +func (p *Passthrough) Reserve(sz int) []byte { + return p.arena.Reserve(sz) +} + +func (p *Passthrough) Commit(pkt []byte) error { + p.slots = append(p.slots, pkt) + return nil +} + +func (p *Passthrough) Flush() error { + var firstErr error + for _, s := range p.slots { + _, err := p.out.Write(s) + if err != nil && firstErr == nil { + firstErr = err + } + } + clear(p.slots) + p.slots = p.slots[:0] + p.arena.Reset() + return firstErr +} diff --git a/overlay/batch/tx_batch.go b/overlay/batch/tx_batch.go new file mode 100644 index 00000000..599bc306 --- /dev/null +++ b/overlay/batch/tx_batch.go @@ -0,0 +1,65 @@ +package batch + +import ( + "net/netip" + + "github.com/slackhq/nebula/udp" +) + +const SendBatchCap = 128 + +// DefaultSendBatchArenaCap is the recommended arena capacity for a +// standalone SendBatch: 128 slots × (udp.MTU + 32) ≈ 1.1 MiB. The +32 covers +// the nebula header + AEAD tag tacked onto each plaintext segment. +const DefaultSendBatchArenaCap = SendBatchCap * (udp.MTU + 32) + +// batchWriter is the minimal subset of udp.Conn needed by SendBatch to flush. +type batchWriter interface { + WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error +} + +// SendBatch accumulates encrypted UDP packets and flushes them via WriteBatch. +// One SendBatch is owned by each listenIn goroutine; no locking is needed. +// Slot bytes are borrowed from the injected Arena and remain valid until +// Flush, which Resets the arena. +type SendBatch struct { + out batchWriter + bufs [][]byte + dsts []netip.AddrPort + ecns []byte + arena *Arena +} + +// NewSendBatch makes a SendBatch with batchCap slots backed by arena. +func NewSendBatch(out batchWriter, batchCap int, arena *Arena) *SendBatch { + return &SendBatch{ + out: out, + bufs: make([][]byte, 0, batchCap), + dsts: make([]netip.AddrPort, 0, batchCap), + ecns: make([]byte, 0, batchCap), + arena: arena, + } +} + +func (b *SendBatch) Reserve(sz int) []byte { + return b.arena.Reserve(sz) +} + +func (b *SendBatch) Commit(pkt []byte, dst netip.AddrPort, outerECN byte) { + b.bufs = append(b.bufs, pkt) + b.dsts = append(b.dsts, dst) + b.ecns = append(b.ecns, outerECN) +} + +func (b *SendBatch) Flush() error { + var err error + if len(b.bufs) > 0 { + err = b.out.WriteBatch(b.bufs, b.dsts, b.ecns) + } + clear(b.bufs) + b.bufs = b.bufs[:0] + b.dsts = b.dsts[:0] + b.ecns = b.ecns[:0] + b.arena.Reset() + return err +} diff --git a/overlay/batch/tx_batch_test.go b/overlay/batch/tx_batch_test.go new file mode 100644 index 00000000..59b06e58 --- /dev/null +++ b/overlay/batch/tx_batch_test.go @@ -0,0 +1,124 @@ +package batch + +import ( + "net/netip" + "testing" +) + +type fakeBatchWriter struct { + bufs [][]byte + addrs []netip.AddrPort + ecns []byte +} + +func (w *fakeBatchWriter) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error { + // Snapshot — SendBatch.Flush nils its slot pointers right after WriteBatch + // returns, so tests must capture data before that happens. 
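+	// (Flush also Resets the shared arena, so the bytes themselves would
+	// be overwritten by the next Reserve even if the pointers survived.)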
+ w.bufs = make([][]byte, len(bufs)) + for i, b := range bufs { + cp := make([]byte, len(b)) + copy(cp, b) + w.bufs[i] = cp + } + w.addrs = append(w.addrs[:0], addrs...) + w.ecns = append(w.ecns[:0], ecns...) + return nil +} + +func TestSendBatchReserveCommitFlush(t *testing.T) { + fw := &fakeBatchWriter{} + b := NewSendBatch(fw, 4, NewArena(32)) + + ap := netip.MustParseAddrPort("10.0.0.1:4242") + for i := 0; i < 4; i++ { + slot := b.Reserve(32) + if cap(slot) != 32 { + t.Fatalf("slot %d: cap=%d want 32", i, cap(slot)) + } + pkt := append(slot[:0], byte(i), byte(i+1), byte(i+2)) + b.Commit(pkt, ap, 0) + } + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + if len(fw.bufs) != 4 { + t.Fatalf("WriteBatch got %d bufs want 4", len(fw.bufs)) + } + for i, buf := range fw.bufs { + if len(buf) != 3 || buf[0] != byte(i) { + t.Errorf("buf %d: %x", i, buf) + } + if fw.addrs[i] != ap { + t.Errorf("addr %d: got %v want %v", i, fw.addrs[i], ap) + } + } + + // Flush again with nothing committed — should be a no-op. + fw.bufs = nil + if err := b.Flush(); err != nil { + t.Fatalf("empty Flush: %v", err) + } + if fw.bufs != nil { + t.Fatalf("empty Flush triggered WriteBatch") + } + + // Reuse after Flush. + slot := b.Reserve(32) + if cap(slot) != 32 { + t.Fatalf("after Flush Reserve wrong cap: %d", cap(slot)) + } +} + +func TestSendBatchSlotsDoNotOverlap(t *testing.T) { + fw := &fakeBatchWriter{} + b := NewSendBatch(fw, 3, NewArena(8)) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + for i := 0; i < 3; i++ { + s := b.Reserve(8) + pkt := append(s[:0], byte(0xA0+i), byte(0xB0+i)) + b.Commit(pkt, ap, 0) + } + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + for i, buf := range fw.bufs { + if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) { + t.Errorf("slot %d corrupted: %x", i, buf) + } + } +} + +func TestSendBatchGrowPreservesCommitted(t *testing.T) { + fw := &fakeBatchWriter{} + // Tiny initial backing forces a grow on the second Reserve. + b := NewSendBatch(fw, 1, NewArena(4)) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + s1 := b.Reserve(4) + pkt1 := append(s1[:0], 0x11, 0x22, 0x33, 0x44) + b.Commit(pkt1, ap, 0) + + s2 := b.Reserve(8) // exceeds remaining cap, triggers grow + pkt2 := append(s2[:0], 0xA, 0xB, 0xC, 0xD, 0xE) + b.Commit(pkt2, ap, 0) + + // pkt1 must still be intact even though backing reallocated. + if pkt1[0] != 0x11 || pkt1[3] != 0x44 { + t.Fatalf("first packet corrupted by grow: %x", pkt1) + } + + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + if len(fw.bufs) != 2 { + t.Fatalf("got %d bufs want 2", len(fw.bufs)) + } + if fw.bufs[0][0] != 0x11 || fw.bufs[0][3] != 0x44 { + t.Errorf("first packet on the wire: %x", fw.bufs[0]) + } + if fw.bufs[1][0] != 0xA || fw.bufs[1][4] != 0xE { + t.Errorf("second packet on the wire: %x", fw.bufs[1]) + } +} diff --git a/overlay/device.go b/overlay/device.go index cff3ac7d..8044ee75 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -8,6 +8,10 @@ import ( "github.com/slackhq/nebula/routing" ) +// defaultBatchBufSize is the per-Queue scratch size for Read on backends +// that don't do TSO segmentation. 65535 covers any single IP packet. +const defaultBatchBufSize = 65535 + type Device interface { io.Closer Activate() error diff --git a/overlay/tio/segment.go b/overlay/tio/segment.go new file mode 100644 index 00000000..67648ad2 --- /dev/null +++ b/overlay/tio/segment.go @@ -0,0 +1,12 @@ +package tio + +import "fmt" + +// SegmentSuperpacket invokes fn once per segment of pkt. 
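+// fn must not retain seg beyond the call, and an error returned by fn is
+// propagated to the caller.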
+// This is a stub implementation that does not actually support segmentation +func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error { + if pkt.GSO.IsSuperpacket() { + return fmt.Errorf("tio: GSO superpacket on platform without segmentation support") + } + return fn(pkt.Bytes) +} diff --git a/overlay/tio/tio.go b/overlay/tio/tio.go index bc0b05df..92aac452 100644 --- a/overlay/tio/tio.go +++ b/overlay/tio/tio.go @@ -18,7 +18,12 @@ type QueueSet interface { // Capabilities advertises which kernel offload features a Queue successfully negotiated. // Callers consult this to decide which coalescers to wire onto the write path. type Capabilities struct { - //none yet! + // TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed + // to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe. + TSO bool + // USO means the kernel additionally agreed to TUN_F_USO4|USO6, so + // WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2. + USO bool } // Queue is a readable/writable Poll queue. One Queue is driven by a single @@ -40,3 +45,78 @@ type Queue interface { // or the zero value when q does not advertise any. Capabilities() Capabilities } + +// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes. +// The zero value means "not a superpacket" — Bytes is one regular IP +// datagram and no segmentation is required. +type GSOInfo struct { + // Size is the GSO segment size: max payload bytes per segment + // (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means + // not a superpacket. + Size uint16 + // HdrLen is the total L3+L4 header length within Bytes (already + // corrected via correctHdrLen, so safe to slice on). + HdrLen uint16 + // CsumStart is the L4 header offset inside Bytes (== L3 header + // length). + CsumStart uint16 + // Proto picks the L4 protocol (TCP or UDP) so the segmenter knows + // which checksum/header layout to apply. + Proto GSOProto +} + +// GSOProto selects the L4 protocol for a GSO superpacket. Determines which +// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset +// inside the transport header virtio NEEDS_CSUM expects. +type GSOProto uint8 + +const ( + GSOProtoNone GSOProto = iota + GSOProtoTCP + GSOProtoUDP +) + +// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket +// assembled from a header prefix plus one or more borrowed payload +// fragments, in a single vectored write (writev with a leading +// virtio_net_hdr). This lets the coalescer avoid copying payload bytes +// between the caller's decrypt buffer and the TUN. Backends without GSO +// support do not implement this interface and coalescing is skipped. +// +// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have +// filled in total length and IP csum). transportHdr is the TCP or UDP +// header (mutable - the L4 checksum field must hold the pseudo-header +// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics). +// pays are non-overlapping payload fragments whose concatenation is the +// full superpacket payload; they are read-only from the writer's +// perspective and must remain valid until the call returns. Every segment +// in pays except possibly the last is exactly the same size. proto picks +// the L4 protocol so the writer knows which GSOType / CsumOffset to set. 
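+//
+// Conceptually the emitted writev is [virtio_net_hdr][hdr][transportHdr][pays...];
+// the kernel then re-segments on the stamped gso_size.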
+//
+// Callers should also consult the queue's negotiated Capabilities (via
+// SupportsGSO) for the per-protocol capability; an implementation of
+// GSOWriter is necessary but not sufficient since USO may not have been
+// negotiated even when TSO was.
+type GSOWriter interface {
+	WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error
+}
+
+// SupportsGSO reports whether w implements GSOWriter and the underlying
+// queue advertises the negotiated capability for `want`. The decision is
+// driven entirely by w.Capabilities(); tests and fakes that don't
+// negotiate should simply advertise the capability they emulate.
+func SupportsGSO(w Queue, want GSOProto) (GSOWriter, bool) {
+	gw, ok := w.(GSOWriter)
+	if !ok {
+		return nil, false
+	}
+	caps := w.Capabilities()
+	switch want {
+	case GSOProtoTCP:
+		return gw, caps.TSO
+	case GSOProtoUDP:
+		return gw, caps.USO
+	default:
+		return gw, false
+	}
+}
\ No newline at end of file
diff --git a/udp/conn.go b/udp/conn.go
index 30d89dec..37277054 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -8,16 +8,49 @@ import (
 
 const MTU = 9001
 
+// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is
+// required to accept. Callers SHOULD NOT pass more than this per call; Linux
+// backends preallocate sendmmsg scratch sized to this value, so exceeding it
+// only costs additional sendmmsg chunks within a single WriteBatch call.
+const MaxWriteBatch = 128
+
+// RxMeta carries per-packet metadata extracted from the RX path (ancillary
+// data, kernel offload state, etc.) and passed to EncReader callbacks.
+// Backends that do not produce a particular signal leave its zero value.
+//
+// OuterECN is the 2-bit IP-level ECN codepoint stamped on the carrier
+// datagram (extracted from IP_TOS / IPV6_TCLASS cmsg on Linux). Zero
+// means Not-ECT, which is also the value backends without ECN RX support
+// supply on every packet.
+type RxMeta struct {
+	OuterECN byte
+}
+
 type EncReader func(
 	addr netip.AddrPort,
 	payload []byte,
+	meta RxMeta,
 )
 
 type Conn interface {
 	Rebind() error
 	LocalAddr() (netip.AddrPort, error)
-	ListenOut(r EncReader) error
+	// ListenOut invokes r for each received packet. On batch-capable
+	// backends (recvmmsg), flush is called after each batch is fully
+	// delivered — callers use it to flush per-batch accumulators such as
+	// TUN write coalescers. Single-packet backends call flush after each
+	// packet. flush must not be nil.
+	ListenOut(r EncReader, flush func()) error
 	WriteTo(b []byte, addr netip.AddrPort) error
+	// WriteBatch sends a contiguous batch of packets, each with its own
+	// destination. bufs and addrs must have the same length. outerECNs may
+	// be nil (treated as all-zero / Not-ECT); when non-nil it must have the
+	// same length as bufs, and outerECNs[i] is the 2-bit IP-level ECN
+	// codepoint to set on packet i's outer header. Linux sends the batch
+	// with a single sendmmsg(2); attaching the value as an IP_TOS /
+	// IPV6_TCLASS cmsg is the intended Linux behavior but is not wired up
+	// yet, and other backends ignore it. Returns on the first error; callers
+	// may observe a partial send if some packets went out before the error.
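+	// Implementations may chunk internally (see MaxWriteBatch), so a
+	// single call can turn into several syscalls for a large batch.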
+ WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool Close() error @@ -31,7 +64,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) error { +func (NoopConn) ListenOut(_ EncReader, _ func()) error { return nil } func (NoopConn) SupportsMultipleReaders() bool { @@ -40,6 +73,9 @@ func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort, _ []byte) error { + return nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 8a4f5b18..e6ecea8f 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -165,7 +174,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) for { @@ -179,7 +188,8 @@ func (u *StdConn) ListenOut(r EncReader) error { u.l.Error("unexpected udp socket receive error", "error", err) } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{}) + flush() } } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 131eb73b..0c254906 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -44,6 +44,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { return err } +func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -73,7 +82,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) error { +func (u *GenericConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -93,7 +102,8 @@ func (u *GenericConn) ListenOut(r EncReader) error { continue } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{}) + flush() } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 3e2d726a..6465be32 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -24,6 +24,22 @@ type StdConn struct { isV4 bool l *slog.Logger batch int + + // sendmmsg scratch. Each queue has its own StdConn, so no locking is + // needed. Sized to MaxWriteBatch at construction; WriteBatch chunks + // larger inputs. + writeMsgs []rawMessage + writeIovs []iovec + writeNames [][]byte + + // sendmmsg(2) callback state. sendmmsgCB is bound once in NewListener + // to the sendmmsgRun method value so passing it to rawConn.Write does + // not allocate a fresh closure per send; sendmmsgN/Sent/Errno carry + // the inputs and outputs across the call without escaping locals. 
+ sendmmsgCB func(fd uintptr) bool + sendmmsgN int + sendmmsgSent int + sendmmsgErrno syscall.Errno } func setReusePort(network, address string, c syscall.RawConn) error { @@ -70,9 +86,23 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) } out.isV4 = af == unix.AF_INET + out.prepareWriteMessages(MaxWriteBatch) + out.sendmmsgCB = out.sendmmsgRun + return out, nil } +func (u *StdConn) prepareWriteMessages(n int) { + u.writeMsgs = make([]rawMessage, n) + u.writeIovs = make([]iovec, n) + u.writeNames = make([][]byte, n) + + for i := range u.writeMsgs { + u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0] + } +} + func (u *StdConn) SupportsMultipleReaders() bool { return true } @@ -171,7 +201,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) error { +func (u *StdConn) listenOutSingle(r EncReader, flush func()) error { var err error var n int var from netip.AddrPort @@ -183,16 +213,33 @@ func (u *StdConn) listenOutSingle(r EncReader) error { return err } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) - r(from, buffer[:n]) + // listenOutSingle uses ReadFromUDPAddrPort which discards cmsgs, + // so the outer ECN field is not visible on this path. Zero RxMeta + // (Not-ECT) means RFC 6040 combine is a no-op. + r(from, buffer[:n], RxMeta{}) + flush() } } -func (u *StdConn) listenOutBatch(r EncReader) error { +// readSockaddr decodes the source address out of a recvmmsg name buffer +func (u *StdConn) readSockaddr(name []byte) netip.AddrPort { var ip netip.Addr + // It's ok to skip the ok check here, the slicing is the only error that can occur and it will panic + if u.isV4 { + ip, _ = netip.AddrFromSlice(name[4:8]) + } else { + ip, _ = netip.AddrFromSlice(name[8:24]) + } + return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(name[2:4])) +} + +func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { var n int var operr error - msgs, buffers, names := u.PrepareRawMessages(u.batch) + bufSize := MTU + cmsgSpace := 0 + msgs, buffers, names, _ := u.PrepareRawMessages(u.batch, bufSize, cmsgSpace) //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read //defining it outside the loop so it gets re-used @@ -211,22 +258,18 @@ func (u *StdConn) listenOutBatch(r EncReader) error { } for i := 0; i < n; i++ { - // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic - if u.isV4 { - ip, _ = netip.AddrFromSlice(names[i][4:8]) - } else { - ip, _ = netip.AddrFromSlice(names[i][8:24]) - } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + r(u.readSockaddr(names[i]), buffers[i][:msgs[i].Len], RxMeta{}) } + + flush() } } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { if u.batch == 1 { - return u.listenOutSingle(r) + return u.listenOutSingle(r, flush) } else { - return u.listenOutBatch(r) + return u.listenOutBatch(r, flush) } } @@ -235,6 +278,120 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return err } +// WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on +// StdConn. 
Each packet occupies its own
+// mmsghdr entry; coalescing runs of same-destination, same-size packets
+// into one entry (UDP GSO style) is not implemented yet, and the
+// outerECNs argument is currently ignored on this path.
+//
+// If sendmmsg returns an error and zero entries went out, we fall back to
+// per-packet WriteTo for that chunk so the caller still gets best-effort
+// delivery. On a partial send we resume at the first unsent entry on
+// the next iteration.
+func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
+	for i := 0; i < len(bufs); {
+		chunk := min(len(bufs)-i, len(u.writeMsgs))
+
+		for k := 0; k < chunk; k++ {
+			u.writeIovs[k].Base = &bufs[i+k][0]
+			setIovLen(&u.writeIovs[k], len(bufs[i+k]))
+
+			nlen, err := writeSockaddr(u.writeNames[k], addrs[i+k], u.isV4)
+			if err != nil {
+				return err
+			}
+
+			hdr := &u.writeMsgs[k].Hdr
+			hdr.Iov = &u.writeIovs[k]
+			setMsgIovlen(hdr, 1)
+			hdr.Namelen = uint32(nlen)
+		}
+
+		sent, serr := u.sendmmsg(chunk)
+		if serr != nil && sent <= 0 {
+			// sendmmsg returns -1 / sent=0 when entry 0 itself failed; log
+			// that entry's destination and fall back to per-packet WriteTo
+			// for the whole chunk so the caller still gets best-effort
+			// delivery without duplicating packets the kernel accepted.
+			u.l.Warn("sendmmsg failed, falling back to per-packet WriteTo",
+				"error", serr,
+				"entries", chunk,
+				"entry0Dst", addrs[i],
+				"isV4", u.isV4,
+			)
+			for k := 0; k < chunk; k++ {
+				if werr := u.WriteTo(bufs[i+k], addrs[i+k]); werr != nil {
+					return werr
+				}
+			}
+			i += chunk
+			continue
+		}
+		i += sent
+	}
+	return nil
+}
+
+// sendmmsg issues sendmmsg(2) against the first n entries of u.writeMsgs.
+// The bound u.sendmmsgCB is passed to rawConn.Write so no closure is
+// allocated per call; inputs and outputs ride on the StdConn fields.
+func (u *StdConn) sendmmsg(n int) (int, error) {
+	u.sendmmsgN = n
+	u.sendmmsgSent = 0
+	u.sendmmsgErrno = 0
+	if err := u.rawConn.Write(u.sendmmsgCB); err != nil {
+		return u.sendmmsgSent, err
+	}
+	if u.sendmmsgErrno != 0 {
+		return u.sendmmsgSent, &net.OpError{Op: "sendmmsg", Err: u.sendmmsgErrno}
+	}
+	return u.sendmmsgSent, nil
+}
+
+// sendmmsgRun is the rawConn.Write callback. It is bound once into
+// u.sendmmsgCB at construction so it stays alloc-free in the hot path;
+// inputs (sendmmsgN) and outputs (sendmmsgSent, sendmmsgErrno) ride on
+// the receiver rather than escaping locals.
+func (u *StdConn) sendmmsgRun(fd uintptr) bool {
+	r1, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, fd,
+		uintptr(unsafe.Pointer(&u.writeMsgs[0])), uintptr(u.sendmmsgN),
+		0, 0, 0,
+	)
+	if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK {
+		return false
+	}
+	u.sendmmsgSent = int(r1)
+	u.sendmmsgErrno = errno
+	return true
+}
+
+// writeSockaddr encodes addr into buf (which must be at least
+// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is
+// true and addr is not a v4 (or v4-in-v6) address, returns an error.
+func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) {
+	ap := addr.Addr().Unmap()
+	if isV4 {
+		if !ap.Is4() {
+			return 0, ErrInvalidIPv6RemoteForSocket
+		}
+		// struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) }
+		// sa_family is host endian. 
+ binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + ip4 := ap.As4() + copy(buf[4:8], ip4[:]) + clear(buf[8:16]) + return unix.SizeofSockaddrInet4, nil + } + // struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) } + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + binary.NativeEndian.PutUint32(buf[4:8], 0) + ip6 := addr.Addr().As16() + copy(buf[8:24], ip6[:]) + binary.NativeEndian.PutUint32(buf[24:28], 0) + return unix.SizeofSockaddrInet6, nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..0f153a49 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -30,13 +30,18 @@ type rawMessage struct { Len uint32 } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var cmsgs []byte + if cmsgSpace > 0 { + cmsgs = make([]byte, n*cmsgSpace) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ @@ -48,7 +53,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if cmsgSpace > 0 { + msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace] + msgs[i].Hdr.Controllen = uint32(cmsgSpace) + } } - return msgs, buffers, names + return msgs, buffers, names, cmsgs +} + +func setIovLen(v *iovec, n int) { + v.Len = uint32(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint32(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint32(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint32(n) } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..dc373538 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -33,13 +33,18 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var cmsgs []byte + if cmsgSpace > 0 { + cmsgs = make([]byte, n*cmsgSpace) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ @@ -51,7 +56,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if cmsgSpace > 0 { + msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace] + msgs[i].Hdr.Controllen = uint64(cmsgSpace) + } } - return msgs, buffers, names + return msgs, buffers, names, cmsgs +} + +func setIovLen(v *iovec, n int) { + v.Len = uint64(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint64(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint64(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint64(n) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index d110af19..a95ad3d0 100644 --- a/udp/udp_rio_windows.go +++ 
b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) error { +func (u *RIOConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -161,7 +161,8 @@ func (u *RIOConn) ListenOut(r EncReader) error { continue } - r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], RxMeta{}) + flush() } } @@ -316,6 +317,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } +func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index f872e32a..6b877b71 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -157,15 +157,24 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } } +func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} -func (u *TesterConn) ListenOut(r EncReader) error { +func (u *TesterConn) ListenOut(r EncReader, flush func()) error { for { select { case <-u.done: return os.ErrClosed case p := <-u.RxPackets: - r(p.From, p.Data) + r(p.From, p.Data, RxMeta{}) p.Release() + flush() } } }