diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index b555fbc4..1691aeab 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,15 +4,13 @@ package e2e import ( - "io" + "log/slog" "net/netip" "os" "strings" "testing" "time" - "log/slog" - "dario.cat/mergo" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -382,7 +380,7 @@ func getAddrs(ns []netip.Prefix) []netip.Addr { func NewTestLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - return slog.New(slog.NewTextHandler(io.Discard, nil)) + return slog.New(slog.DiscardHandler) } level := slog.LevelInfo diff --git a/handshake_manager.go b/handshake_manager.go index 87257028..1384b346 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -974,6 +974,7 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn nb := make([]byte, 12, 12) out := make([]byte, mtu) for _, cp := range hh.packetStore { + //todo use a sendbatcher cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) } f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) diff --git a/inside.go b/inside.go index 68cb38ec..35da66d3 100644 --- a/inside.go +++ b/inside.go @@ -2,6 +2,7 @@ package nebula import ( "context" + "io" "log/slog" "net/netip" @@ -9,10 +10,16 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/overlay/batch" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) { + // borrowed: pkt.Bytes is owned by the originating tio.Queue and is + // only valid until the next Read on that queue. 
If you must keep + // the packet, use pkt.Clone() to detach it + packet := pkt.Bytes err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { @@ -37,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { - _, err := f.readers[q].Write(packet) + err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + _, werr := f.readers[q].Write(seg) + return werr + }) if err != nil { f.l.Error("Failed to forward to tun", "error", err) } @@ -53,11 +63,23 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + // borrowed: SegmentSuperpacket builds each segment in the kernel-supplied pkt + // bytes underneath. cachePacket explicitly copies its argument (handshake_manager.go cachePacket), + // so retaining segments past the loop is safe. 
+ err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + hh.cachePacket(f.l, header.Message, 0, seg, f.sendMessageNow, f.cachedPacketMetrics) + return nil + }) + if err != nil && f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Failed to segment superpacket for handshake cache", + "error", err, + "vpnAddr", fwPacket.RemoteAddr, + ) + } }) if hostinfo == nil { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", "vpnAddr", fwPacket.RemoteAddr, @@ -73,10 +95,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - + f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q) } else { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(f.l).Debug("dropping outbound packet", "fwPacket", fwPacket, @@ -86,6 +107,124 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } +func (f *Interface) sendInsideEncrypt(hostinfo *HostInfo, ci *ConnectionState, seg, scratch, nb []byte) []byte { + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + out, encErr := ci.eKey.EncryptDanger(out, out, seg, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if encErr != nil { + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", encErr, + "udpAddr", hostinfo.remote, + "counter", c, + ) + // Skip this segment; the 
rest of the superpacket can still + // go out — TCP will retransmit anything we drop here. + return nil + } + + return out +} + +// sendInsideMessage encrypts a firewall-approved inside packet (or every +// segment of a TSO/USO superpacket) into the caller's batch slot for +// later sendmmsg flush. Segmentation is fused with encryption here so the +// kernel-supplied superpacket bytes never get written into a separate +// scratch arena: SegmentSuperpacket builds each segment's plaintext in +// segScratch[:segLen] in turn, and we encrypt directly into a fresh +// SendBatch slot. +func (f *Interface) sendInsideMessage(hostinfo *HostInfo, pkt tio.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) { + ci := hostinfo.ConnectionState + if ci.eKey == nil { + return + } + + if hostinfo.lastRebindCount != f.rebindCount { + //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is + // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. 
+ f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Lighthouse update triggered for punch due to rebind counter", + "vpnAddrs", hostinfo.vpnAddrs, + ) + } + } + + if !hostinfo.remote.IsValid() { //the relay path + //first, find our relay hostinfo: + var relayHostInfo *HostInfo + var relay *Relay + var err error + for _, relayIP := range hostinfo.relayState.CopyRelayIps() { + relayHostInfo, relay, err = f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) + if err != nil { + hostinfo.relayState.DeleteRelay(relayIP) + hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", + "relay", relayIP, + "error", err, + ) + continue + } + break + } + if relayHostInfo == nil || relay == nil { + //failure already logged + return + } + + err = tio.SegmentSuperpacket(pkt, func(seg []byte) error { + //relay header + header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305) + relay tag + scratch := sendBatch.Reserve(header.Len + header.Len + len(seg) + 16 + 16) + + innerPacket := f.sendInsideEncrypt(hostinfo, ci, seg, scratch[header.Len:], nb) + if innerPacket == nil { + return nil + } + + //now we need to do a relay-encrypt: + toSend, err := f.prepareSendVia(relayHostInfo, relay, innerPacket, nb, scratch, true) + if err != nil { + //already logged + return nil + } + + sendBatch.Commit(toSend, relayHostInfo.remote, 0) + return nil + }) + if err != nil { + hostinfo.logger(f.l).Error("Failed to segment superpacket for relay send", "error", err) + } + return + } + + err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + // header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305) + scratch := sendBatch.Reserve(header.Len + len(seg) + 16) + + out := f.sendInsideEncrypt(hostinfo, ci, seg, scratch, nb) + if out == nil { + return nil + } + + sendBatch.Commit(out, hostinfo.remote, 0) + return nil 
+ }) + if err != nil { + hostinfo.logger(f.l).Error("Failed to segment superpacket for send", + "error", err, + ) + } +} + func (f *Interface) rejectInside(packet []byte, out []byte, q int) { if !f.firewall.InSendReject { return @@ -275,21 +414,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } -// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done -// to the payload for the ultimate target host, making this a useful method for sending -// handshake messages to peers through relay tunnels. -// via is the HostInfo through which the message is relayed. -// ad is the plaintext data to authenticate, but not encrypt -// nb is a buffer used to store the nonce value, re-used for performance reasons. -// out is a buffer used to store the result of the Encrypt operation -// q indicates which writer to use to send the packet. -func (f *Interface) SendVia(via *HostInfo, +func (f *Interface) prepareSendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool, -) { +) ([]byte, error) { if noiseutil.EncryptLockNeeded { // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check via.ConnectionState.writeLock.Lock() @@ -311,7 +442,7 @@ func (f *Interface) SendVia(via *HostInfo, "headerLen", len(out), "cipherOverhead", via.ConnectionState.eKey.Overhead(), ) - return + return nil, io.ErrShortBuffer } // The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload. @@ -331,13 +462,32 @@ func (f *Interface) SendVia(via *HostInfo, } if err != nil { via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) - return + return nil, err } - err = f.writers[0].WriteTo(out, via.remote) + f.connectionManager.RelayUsed(relay.LocalIndex) + return out, nil +} + +// SendVia sends a payload through a Relay tunnel. 
No authentication or encryption is done
+// to the payload for the ultimate target host, making this a useful method for sending
+// handshake messages to peers through relay tunnels.
+// via is the HostInfo through which the message is relayed.
+// ad is the plaintext data to authenticate, but not encrypt
+// nb is a buffer used to store the nonce value, re-used for performance reasons.
+// out is a buffer used to store the result of the Encrypt operation
+// q indicates which writer to use to send the packet.
+func (f *Interface) SendVia(via *HostInfo,
+	relay *Relay,
+	ad,
+	nb,
+	out []byte,
+	nocopy bool,
+) {
+	toSend, err := f.prepareSendVia(via, relay, ad, nb, out, nocopy)
+	if err != nil {
+		// prepareSendVia already logged the failure; bail out so a nil
+		// toSend never reaches WriteTo (previously this error was
+		// immediately overwritten by the WriteTo result and lost).
+		return
+	}
+	err = f.writers[0].WriteTo(toSend, via.remote)
 	if err != nil {
 		via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
 	}
-	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 
 func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
diff --git a/interface.go b/interface.go
index 5fedcdd3..2af581f2 100644
--- a/interface.go
+++ b/interface.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io"
 	"log/slog"
 	"net/netip"
 	"sync"
@@ -13,11 +12,12 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
-	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/overlay"
+	"github.com/slackhq/nebula/overlay/batch"
+	"github.com/slackhq/nebula/overlay/tio"
 	"github.com/slackhq/nebula/udp"
 )
 
@@ -47,7 +47,8 @@ type InterfaceConfig struct {
 	reQueryWait time.Duration
 
 	ConntrackCacheTimeout time.Duration
-	l                     *slog.Logger
+
+	l *slog.Logger
 }
 
 type Interface struct {
@@ -88,8 +89,12 @@ type Interface struct {
 	ctx context.Context
 
 	writers []udp.Conn
-	readers []io.ReadWriteCloser
-	wg      sync.WaitGroup
+	readers []tio.Queue
+	// batchers is one per tun queue, wrapping readers[i].
+ // decryptToTun sends plaintext into the batch.RxBatcher; + // listenOut calls its Flush at the end of each UDP recvmmsg batch. + batchers []batch.RxBatcher + wg sync.WaitGroup // fatalErr holds the first unexpected reader error that caused shutdown. // nil means "no fatal error" (yet) @@ -187,7 +192,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), + readers: make([]tio.Queue, c.routines), + batchers: make([]batch.RxBatcher, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, @@ -245,15 +251,17 @@ func (f *Interface) activate() error { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues - var reader io.ReadWriteCloser = f.inside for i := 0; i < f.routines; i++ { if i > 0 { - reader, err = f.inside.NewMultiQueueReader() - if err != nil { + if err = f.inside.NewMultiQueueReader(); err != nil { return err } } - f.readers[i] = reader + } + f.readers = f.inside.Readers() + for i := range f.readers { + arena := batch.NewArena(batch.DefaultPassthroughArenaCap) + f.batchers[i] = batch.NewPassthrough(f.readers[i], arena) } f.wg.Add(1) // for us to wait on Close() to return @@ -311,14 +319,22 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) h := &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) - }) + listener := func(fromUdpAddr netip.AddrPort, payload []byte, meta udp.RxMeta) { + plaintext := f.batchers[i].Reserve(len(payload)) + 
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(), meta) + } + + flusher := func() { + if err := f.batchers[i].Flush(); err != nil { + f.l.Error("Failed to flush tun coalescer", "error", err) + } + } + + err := li.ListenOut(listener, flusher) if err != nil && !f.closed.Load() { f.l.Error("Error while reading inbound packet, closing", "error", err) @@ -328,16 +344,17 @@ func (f *Interface) listenOut(i int) { f.l.Debug("underlay reader is done", "reader", i) } -func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - packet := make([]byte, mtu) - out := make([]byte, mtu) +func (f *Interface) listenIn(reader tio.Queue, i int) { + rejectBuf := make([]byte, mtu) + arenaSize := batch.SendBatchCap * (udp.MTU + 32) + sb := batch.NewSendBatch(f.writers[i], batch.SendBatchCap, arenaSize) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) for { - n, err := reader.Read(packet) + pkts, err := reader.Read() if err != nil { if !f.closed.Load() { f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) @@ -346,7 +363,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) + for _, pkt := range pkts { + f.consumeInsidePacket(pkt, fwPacket, nb, sb, rejectBuf, i, conntrackCache.Get()) + } + if err := sb.Flush(); err != nil { + f.l.Error("Failed to write outgoing batch", "error", err, "writer", i) + } } f.l.Debug("overlay reader is done", "reader", i) diff --git a/outside.go b/outside.go index 17013ed3..cf079fd7 100644 --- a/outside.go +++ b/outside.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" ) @@ -22,7 +23,7 @@ const ( var ErrOutOfWindow = errors.New("out of window 
packet") -func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors @@ -110,8 +111,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, // Relay packets are special if isMessageRelay { - f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) - + f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache, meta) return } @@ -135,7 +135,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, case header.Message: switch h.Subtype { case header.MessageNone: - f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache, meta) default: hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) return @@ -168,7 +168,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } } -func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { // The entire body is sent as AD, not 
encrypted. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's @@ -211,7 +211,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, relay: relay, IsRelayed: true, } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, meta) case ForwardingType: // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) @@ -229,7 +229,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, switch targetRelay.Type { case ForwardingType: // Forward this packet through the relay tunnel - // Find the target HostInfo + // Find the target HostInfo //todo it would potentially be nice to batch these f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) case TerminalType: hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") @@ -512,7 +512,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { err := newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", @@ -536,7 +536,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p return } - _, err = f.readers[q].Write(out) + err = f.batchers[q].Commit(out) if err != 
nil { f.l.Error("Failed to write to tun", "error", err) } diff --git a/overlay/batch/batch.go b/overlay/batch/batch.go new file mode 100644 index 00000000..d171d136 --- /dev/null +++ b/overlay/batch/batch.go @@ -0,0 +1,28 @@ +package batch + +import "net/netip" + +type RxBatcher interface { + // Reserve creates a pkt to borrow + Reserve(sz int) []byte + // Commit borrows pkt. The caller must keep pkt valid until the next Flush + Commit(pkt []byte) error + // Flush emits every queued packet in arrival order. Returns the + // first error observed; keeps draining so one bad packet doesn't hold up + // the rest. After Flush returns, borrowed payload slices may be recycled. + Flush() error +} + +type TxBatcher interface { + // Reserve creates a pkt to borrow + Reserve(sz int) []byte + // Commit borrows pkt and records its destination plus the 2-bit + // IP-level ECN codepoint to set on the outer (carrier) header. The + // caller must keep pkt valid until the next Flush. Pass 0 (Not-ECT) + // to leave the outer ECN field unset. + Commit(pkt []byte, dst netip.AddrPort, outerECN byte) + // Flush emits every queued packet via the underlying batch writer in + // arrival order. Returns an errors.Join of one or more errors. After Flush returns, + // borrowed payload slices may be recycled. + Flush() error +} diff --git a/overlay/batch/coalesce_core.go b/overlay/batch/coalesce_core.go new file mode 100644 index 00000000..0213ea9d --- /dev/null +++ b/overlay/batch/coalesce_core.go @@ -0,0 +1,42 @@ +package batch + +// Arena is an injectable byte-slab that hands out non-overlapping borrowed +// slices via Reserve and releases them in bulk via Reset. Coalescers take +// an *Arena at construction so the caller controls the slab lifetime and +// can share one slab across multiple coalescers (MultiCoalescer hands the +// same *Arena to every lane so the lanes don't carry their own backings). +// +// Reserve borrows; the slice is valid until the next Reset. 
The slab grows +// (by allocating a fresh, larger backing array) if a Reserve doesn't fit; +// pre-size the arena via NewArena to avoid that path on the hot path. +type Arena struct { + buf []byte +} + +// NewArena returns an Arena with a pre-allocated backing of the given +// capacity. Pass 0 if you don't intend to call Reserve (e.g. a test that +// only feeds the coalescer pre-made []byte packets via Commit). +func NewArena(capacity int) *Arena { + return &Arena{buf: make([]byte, 0, capacity)} +} + +// Reserve hands out a non-overlapping sz-byte slice from the arena. If the +// request doesn't fit the current backing, a fresh, larger backing is +// allocated; already-borrowed slices reference the old backing and remain +// valid until Reset. +func (a *Arena) Reserve(sz int) []byte { + if len(a.buf)+sz > cap(a.buf) { + newCap := max(cap(a.buf)*2, sz) + a.buf = make([]byte, 0, newCap) + } + start := len(a.buf) + a.buf = a.buf[:start+sz] + return a.buf[start : start+sz : start+sz] +} + +// Reset releases every slice handed out since the last Reset. Callers must +// not use any previously-borrowed slice after this returns. The underlying +// backing array is retained so subsequent Reserves don't re-allocate. +func (a *Arena) Reset() { + a.buf = a.buf[:0] +} diff --git a/overlay/batch/passthrough.go b/overlay/batch/passthrough.go new file mode 100644 index 00000000..aea90fa4 --- /dev/null +++ b/overlay/batch/passthrough.go @@ -0,0 +1,52 @@ +package batch + +import ( + "io" + + "github.com/slackhq/nebula/udp" +) + +// Passthrough is a RxBatcher that doesn't batch anything, it just accumulates and then sends packets. +type Passthrough struct { + out io.Writer + slots [][]byte + arena *Arena + cursor int +} + +const passthroughBaseNumSlots = 128 + +// DefaultPassthroughArenaCap is the recommended arena capacity for a +// standalone Passthrough batcher: 128 slots × udp.MTU ≈ 1.1 MiB. 
+const DefaultPassthroughArenaCap = passthroughBaseNumSlots * udp.MTU + +func NewPassthrough(w io.Writer, arena *Arena) *Passthrough { + return &Passthrough{ + out: w, + slots: make([][]byte, 0, passthroughBaseNumSlots), + arena: arena, + } +} + +func (p *Passthrough) Reserve(sz int) []byte { + return p.arena.Reserve(sz) +} + +func (p *Passthrough) Commit(pkt []byte) error { + p.slots = append(p.slots, pkt) + return nil +} + +func (p *Passthrough) Flush() error { + var firstErr error + for _, s := range p.slots { + _, err := p.out.Write(s) + if err != nil && firstErr == nil { + firstErr = err + } + } + clear(p.slots) + p.slots = p.slots[:0] + p.arena.Reset() + return firstErr +} diff --git a/overlay/batch/tx_batch.go b/overlay/batch/tx_batch.go new file mode 100644 index 00000000..38f86b25 --- /dev/null +++ b/overlay/batch/tx_batch.go @@ -0,0 +1,65 @@ +package batch + +import "net/netip" + +const SendBatchCap = 128 + +// batchWriter is the minimal subset of udp.Conn needed by SendBatch to flush. +type batchWriter interface { + WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error +} + +// SendBatch accumulates encrypted UDP packets and flushes them via WriteBatch. +// One SendBatch is owned by each listenIn goroutine; no locking is needed. +// The backing arena grows on demand: when there isn't room for the next slot +// we allocate a fresh backing array. Already-committed slices keep referencing +// the old array and remain valid until Flush drops them. 
+type SendBatch struct { + out batchWriter + bufs [][]byte + dsts []netip.AddrPort + ecns []byte + backing []byte +} + +// NewSendBatch makes a SendBatch with batchCap slots and an arenaSize byte buffer for slices to back those slots +func NewSendBatch(out batchWriter, batchCap, arenaSize int) *SendBatch { + return &SendBatch{ + out: out, + bufs: make([][]byte, 0, batchCap), + dsts: make([]netip.AddrPort, 0, batchCap), + ecns: make([]byte, 0, batchCap), + backing: make([]byte, 0, arenaSize), + } +} + +func (b *SendBatch) Reserve(sz int) []byte { + if len(b.backing)+sz > cap(b.backing) { + // Grow: allocate a fresh backing. Already-committed slices still + // reference the old array and remain valid until Flush drops them. + newCap := max(cap(b.backing)*2, sz) + b.backing = make([]byte, 0, newCap) + } + start := len(b.backing) + b.backing = b.backing[:start+sz] + return b.backing[start : start+sz : start+sz] +} + +func (b *SendBatch) Commit(pkt []byte, dst netip.AddrPort, outerECN byte) { + b.bufs = append(b.bufs, pkt) + b.dsts = append(b.dsts, dst) + b.ecns = append(b.ecns, outerECN) +} + +func (b *SendBatch) Flush() error { + var err error + if len(b.bufs) > 0 { + err = b.out.WriteBatch(b.bufs, b.dsts, b.ecns) + } + clear(b.bufs) + b.bufs = b.bufs[:0] + b.dsts = b.dsts[:0] + b.ecns = b.ecns[:0] + b.backing = b.backing[:0] + return err +} diff --git a/overlay/batch/tx_batch_test.go b/overlay/batch/tx_batch_test.go new file mode 100644 index 00000000..454011dc --- /dev/null +++ b/overlay/batch/tx_batch_test.go @@ -0,0 +1,124 @@ +package batch + +import ( + "net/netip" + "testing" +) + +type fakeBatchWriter struct { + bufs [][]byte + addrs []netip.AddrPort + ecns []byte +} + +func (w *fakeBatchWriter) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error { + // Snapshot — SendBatch.Flush nils its slot pointers right after WriteBatch + // returns, so tests must capture data before that happens. 
+ w.bufs = make([][]byte, len(bufs)) + for i, b := range bufs { + cp := make([]byte, len(b)) + copy(cp, b) + w.bufs[i] = cp + } + w.addrs = append(w.addrs[:0], addrs...) + w.ecns = append(w.ecns[:0], ecns...) + return nil +} + +func TestSendBatchReserveCommitFlush(t *testing.T) { + fw := &fakeBatchWriter{} + b := NewSendBatch(fw, 4, 32) + + ap := netip.MustParseAddrPort("10.0.0.1:4242") + for i := 0; i < 4; i++ { + slot := b.Reserve(32) + if cap(slot) != 32 { + t.Fatalf("slot %d: cap=%d want 32", i, cap(slot)) + } + pkt := append(slot[:0], byte(i), byte(i+1), byte(i+2)) + b.Commit(pkt, ap, 0) + } + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + if len(fw.bufs) != 4 { + t.Fatalf("WriteBatch got %d bufs want 4", len(fw.bufs)) + } + for i, buf := range fw.bufs { + if len(buf) != 3 || buf[0] != byte(i) { + t.Errorf("buf %d: %x", i, buf) + } + if fw.addrs[i] != ap { + t.Errorf("addr %d: got %v want %v", i, fw.addrs[i], ap) + } + } + + // Flush again with nothing committed — should be a no-op. + fw.bufs = nil + if err := b.Flush(); err != nil { + t.Fatalf("empty Flush: %v", err) + } + if fw.bufs != nil { + t.Fatalf("empty Flush triggered WriteBatch") + } + + // Reuse after Flush. + slot := b.Reserve(32) + if cap(slot) != 32 { + t.Fatalf("after Flush Reserve wrong cap: %d", cap(slot)) + } +} + +func TestSendBatchSlotsDoNotOverlap(t *testing.T) { + fw := &fakeBatchWriter{} + b := NewSendBatch(fw, 3, 8) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + for i := 0; i < 3; i++ { + s := b.Reserve(8) + pkt := append(s[:0], byte(0xA0+i), byte(0xB0+i)) + b.Commit(pkt, ap, 0) + } + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + for i, buf := range fw.bufs { + if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) { + t.Errorf("slot %d corrupted: %x", i, buf) + } + } +} + +func TestSendBatchGrowPreservesCommitted(t *testing.T) { + fw := &fakeBatchWriter{} + // Tiny initial backing forces a grow on the second Reserve. 
+ b := NewSendBatch(fw, 1, 4) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + s1 := b.Reserve(4) + pkt1 := append(s1[:0], 0x11, 0x22, 0x33, 0x44) + b.Commit(pkt1, ap, 0) + + s2 := b.Reserve(8) // exceeds remaining cap, triggers grow + pkt2 := append(s2[:0], 0xA, 0xB, 0xC, 0xD, 0xE) + b.Commit(pkt2, ap, 0) + + // pkt1 must still be intact even though backing reallocated. + if pkt1[0] != 0x11 || pkt1[3] != 0x44 { + t.Fatalf("first packet corrupted by grow: %x", pkt1) + } + + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + if len(fw.bufs) != 2 { + t.Fatalf("got %d bufs want 2", len(fw.bufs)) + } + if fw.bufs[0][0] != 0x11 || fw.bufs[0][3] != 0x44 { + t.Errorf("first packet on the wire: %x", fw.bufs[0]) + } + if fw.bufs[1][0] != 0xA || fw.bufs[1][4] != 0xE { + t.Errorf("second packet on the wire: %x", fw.bufs[1]) + } +} diff --git a/overlay/device.go b/overlay/device.go index b6077aba..8044ee75 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -4,15 +4,21 @@ import ( "io" "net/netip" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) +// defaultBatchBufSize is the per-Queue scratch size for Read on backends +// that don't do TSO segmentation. 65535 covers any single IP packet. 
+const defaultBatchBufSize = 65535 + type Device interface { - io.ReadWriteCloser + io.Closer Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool - NewMultiQueueReader() (io.ReadWriteCloser, error) + NewMultiQueueReader() error + Readers() []tio.Queue } diff --git a/overlay/overlaytest/noop.go b/overlay/overlaytest/noop.go index 956da7dd..6a39ab43 100644 --- a/overlay/overlaytest/noop.go +++ b/overlay/overlaytest/noop.go @@ -4,9 +4,9 @@ package overlaytest import ( "errors" - "io" "net/netip" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -31,8 +31,8 @@ func (NoopTun) Name() string { return "noop" } -func (NoopTun) Read([]byte) (int, error) { - return 0, nil +func (NoopTun) Read() ([]tio.Packet, error) { + return nil, nil } func (NoopTun) Write([]byte) (int, error) { @@ -43,8 +43,12 @@ func (NoopTun) SupportsMultiqueue() bool { return false } -func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, errors.New("unsupported") +func (NoopTun) NewMultiQueueReader() error { + return errors.New("unsupported") +} + +func (NoopTun) Readers() []tio.Queue { + return []tio.Queue{NoopTun{}} } func (NoopTun) Close() error { diff --git a/overlay/tio/queueset_poll_linux.go b/overlay/tio/queueset_poll_linux.go new file mode 100644 index 00000000..ab967df4 --- /dev/null +++ b/overlay/tio/queueset_poll_linux.go @@ -0,0 +1,69 @@ +package tio + +import ( + "encoding/binary" + "errors" + "fmt" + + "golang.org/x/sys/unix" +) + +type pollQueueSet struct { + pq []*Poll + // pqi is exactly the same as pq, but stored as the interface type + pqi []Queue + shutdownFd int +} + +func NewPollQueueSet() (QueueSet, error) { + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &pollQueueSet{ + pq: []*Poll{}, + pqi: []Queue{}, + shutdownFd: shutdownFd, + } + 
+ return out, nil +} + +func (c *pollQueueSet) Queues() []Queue { + return c.pqi +} + +func (c *pollQueueSet) Add(fd int) error { + x, err := newPoll(fd, c.shutdownFd) + if err != nil { + return err + } + c.pq = append(c.pq, x) + c.pqi = append(c.pqi, x) + + return nil +} + +func (c *pollQueueSet) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(c.shutdownFd), buf[:]) + return err +} + +func (c *pollQueueSet) Close() error { + errs := []error{} + + if err := c.wakeForShutdown(); err != nil { + errs = append(errs, err) + } + + for _, x := range c.pq { + if err := x.Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} diff --git a/overlay/tio/segment.go b/overlay/tio/segment.go new file mode 100644 index 00000000..67648ad2 --- /dev/null +++ b/overlay/tio/segment.go @@ -0,0 +1,12 @@ +package tio + +import "fmt" + +// SegmentSuperpacket invokes fn once per segment of pkt. +// This is a stub implementation that does not actually support segmentation +func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error { + if pkt.GSO.IsSuperpacket() { + return fmt.Errorf("tio: GSO superpacket on platform without segmentation support") + } + return fn(pkt.Bytes) +} diff --git a/overlay/tio/tio.go b/overlay/tio/tio.go new file mode 100644 index 00000000..2d94c764 --- /dev/null +++ b/overlay/tio/tio.go @@ -0,0 +1,170 @@ +package tio + +import ( + "io" +) + +// QueueSet holds one or many Queue objects and helps close them in an orderly way. +type QueueSet interface { + io.Closer + Queues() []Queue + + // Add takes a tun fd, adds it to the set, and prepares it for use as a Queue. + Add(fd int) error +} + +// Capabilities advertises which kernel offload features a Queue +// successfully negotiated. 
Callers consult this to decide which coalescers +// to wire onto the write path — a Queue without TSO can't usefully accept a +// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer. +type Capabilities struct { + // TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed + // to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe. + TSO bool + // USO means the kernel additionally agreed to TUN_F_USO4|USO6, so + // WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2. + USO bool +} + +// Queue is a readable/writable Poll queue. One Queue is driven by a single +// read goroutine plus a single writer (see Write below). +type Queue interface { + io.Closer + + // Read returns one or more packets. The returned Packet.Bytes slices + // are borrowed from the Queue's internal buffer and are only valid + // until the next Read or Close on this Queue - callers must encrypt + // or copy each slice before the next call. A Packet may carry a + // GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is + // true the caller must segment Bytes before treating it as a single + // IP datagram. Not safe for concurrent Reads. + Read() ([]Packet, error) + + // Write emits a single packet on the plaintext (outside→inside) + // delivery path. Not safe for concurrent Writes. + Write(p []byte) (int, error) +} + +// Packet is the unit Queue.Read returns. Bytes points into the queue's +// internal buffer and is only valid until the next Read or Close on the +// queue that produced it. GSO is the zero value for an already-segmented +// IP datagram; when non-zero it describes a kernel-supplied TSO/USO +// superpacket the caller must segment before consuming. +type Packet struct { + Bytes []byte + GSO GSOInfo +} + +// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes. +// The zero value means "not a superpacket" — Bytes is one regular IP +// datagram and no segmentation is required. 
+type GSOInfo struct { + // Size is the GSO segment size: max payload bytes per segment + // (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means + // not a superpacket. + Size uint16 + // HdrLen is the total L3+L4 header length within Bytes (already + // corrected via correctHdrLen, so safe to slice on). + HdrLen uint16 + // CsumStart is the L4 header offset inside Bytes (== L3 header + // length). + CsumStart uint16 + // Proto picks the L4 protocol (TCP or UDP) so the segmenter knows + // which checksum/header layout to apply. + Proto GSOProto +} + +// IsSuperpacket reports whether g describes a multi-segment GSO/USO +// superpacket that needs segmentation before its bytes can be encrypted +// and sent on the wire. +func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 } + +// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes, +// safe to retain past the next Read or Close on the originating Queue. +// GSO metadata is copied verbatim. Use this only when a caller genuinely +// needs to outlive the borrowed-slice contract — the hot path reads should +// continue to consume the borrow synchronously to avoid the allocation. +func (p Packet) Clone() Packet { + if p.Bytes == nil { + return p + } + cp := make([]byte, len(p.Bytes)) + copy(cp, p.Bytes) + return Packet{Bytes: cp, GSO: p.GSO} +} + +// CapsProvider is an optional interface implemented by Queues that +// successfully negotiated kernel offload features at open time. Callers +// pick a write-path coalescer based on the result. Queues that don't +// implement it are treated as having no offload capability — callers must +// fall back to plain per-packet writes. +type CapsProvider interface { + Capabilities() Capabilities +} + +// QueueCapabilities returns q's negotiated offload capabilities, or the +// zero value when q does not advertise any. 
+func QueueCapabilities(q Queue) Capabilities { + if cp, ok := q.(CapsProvider); ok { + return cp.Capabilities() + } + return Capabilities{} +} + +// GSOProto selects the L4 protocol for a GSO superpacket. Determines which +// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset +// inside the transport header virtio NEEDS_CSUM expects. +type GSOProto uint8 + +const ( + GSOProtoTCP GSOProto = iota + GSOProtoUDP +) + +// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket +// assembled from a header prefix plus one or more borrowed payload +// fragments, in a single vectored write (writev with a leading +// virtio_net_hdr). This lets the coalescer avoid copying payload bytes +// between the caller's decrypt buffer and the TUN. Backends without GSO +// support do not implement this interface and coalescing is skipped. +// +// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have +// filled in total length and IP csum). transportHdr is the TCP or UDP +// header (mutable - the L4 checksum field must hold the pseudo-header +// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics). +// pays are non-overlapping payload fragments whose concatenation is the +// full superpacket payload; they are read-only from the writer's +// perspective and must remain valid until the call returns. Every segment +// in pays except possibly the last is exactly the same size. proto picks +// the L4 protocol so the writer knows which GSOType / CsumOffset to set. +// +// Callers should also consult CapsProvider (via SupportsGSO or +// QueueCapabilities) for the per-protocol negotiated capability; an +// implementation of GSOWriter is necessary but not sufficient since USO +// may not have been negotiated even when TSO was. 
+type GSOWriter interface { + WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error +} + +// SupportsGSO reports whether w implements GSOWriter and the underlying +// queue advertises the negotiated capability for `want`. A writer that +// implements GSOWriter but not CapsProvider is treated as permissive +// (used by tests and fakes that don't negotiate). +func SupportsGSO(w any, want GSOProto) (GSOWriter, bool) { + gw, ok := w.(GSOWriter) + if !ok { + return nil, false + } + cp, ok := w.(CapsProvider) + if !ok { + return gw, true + } + caps := cp.Capabilities() + switch want { + case GSOProtoTCP: + return gw, caps.TSO + case GSOProtoUDP: + return gw, caps.USO + } + return gw, false +} diff --git a/overlay/tio/tio_poll_linux.go b/overlay/tio/tio_poll_linux.go new file mode 100644 index 00000000..2aa58813 --- /dev/null +++ b/overlay/tio/tio_poll_linux.go @@ -0,0 +1,168 @@ +package tio + +import ( + "fmt" + "os" + "sync" + "sync/atomic" + + "golang.org/x/sys/unix" +) + +type Poll struct { + fd int + + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + writeLock sync.Mutex + closed atomic.Bool + + readBuf []byte + batchRet [1]Packet +} + +func newPoll(fd int, shutdownFd int) (*Poll, error) { + if err := unix.SetNonblock(fd, true); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err) + } + + out := &Poll{ + fd: fd, + readBuf: make([]byte, 65535), + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writeLock: sync.Mutex{}, + } + return out, nil +} + +// blockOnRead waits until the Poll fd is readable or shutdown has been signaled. +// Returns os.ErrClosed if Close was called. 
+func (t *Poll) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.readPoll[0].Revents + shutdownEvents := t.readPoll[1].Revents + t.readPoll[0].Revents = 0 + t.readPoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *Poll) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + t.writeLock.Lock() + tunEvents := t.writePoll[0].Revents + shutdownEvents := t.writePoll[1].Revents + t.writePoll[0].Revents = 0 + t.writePoll[1].Revents = 0 + t.writeLock.Unlock() + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *Poll) Read() ([]Packet, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil +} + +func (t *Poll) readOne(to []byte) (int, error) { + for { + n, errno := unix.Read(t.fd, to) + if errno == nil { + return n, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnRead(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } + } +} + +// Write is only valid for single threaded use +func (t *Poll) Write(from []byte) (int, error) { + for { + n, errno := unix.Write(t.fd, from) + if errno == nil { + return n, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnWrite(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case 
unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } + } +} + +func (t *Poll) Close() error { + if t.closed.Swap(true) { + return nil + } + //shutdownFd is owned by the container, so we should not close it + var err error + if t.fd >= 0 { + err = unix.Close(t.fd) + t.fd = -1 + } + + return err +} + +func (t *Poll) Capabilities() Capabilities { + return Capabilities{TSO: false, USO: false} +} diff --git a/overlay/tio/tun_file_linux_test.go b/overlay/tio/tun_file_linux_test.go new file mode 100644 index 00000000..f92f58ec --- /dev/null +++ b/overlay/tio/tun_file_linux_test.go @@ -0,0 +1,82 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package tio + +import ( + "errors" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +// newReadPipe returns a read fd. The matching write fd is registered for cleanup. +// The caller takes ownership of the read fd (pass it into a QueueSet). 
+func newReadPipe(t *testing.T) int { + t.Helper() + var fds [2]int + if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { + t.Fatalf("pipe2: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fds[1]) }) + return fds[0] +} + +func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) { + pipe1 := newReadPipe(t) + pipe2 := newReadPipe(t) + parent, err := NewPollQueueSet() + require.NoError(t, err) + require.NoError(t, parent.Add(pipe1)) + require.NoError(t, parent.Add(pipe2)) + t.Cleanup(func() { + _ = unix.Close(pipe1) + _ = unix.Close(pipe2) + }) + + readers := parent.Queues() + errs := make([]error, len(readers)) + var wg sync.WaitGroup + for i, r := range readers { + wg.Add(1) + go func(i int, r Queue) { + defer wg.Done() + _, errs[i] = r.Read() + }(i, r) + } + + time.Sleep(50 * time.Millisecond) + + if err := parent.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("readers did not wake") + } + + for i, err := range errs { + if !errors.Is(err, os.ErrClosed) { + t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) + } + } +} + +func TestPoll_Close_Idempotent(t *testing.T) { + tf, err := newPoll(newReadPipe(t), 1) + require.NoError(t, err) + if err := tf.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("second Close should be a no-op, got %v", err) + } +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 9cbb64be..ea2e1295 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -13,17 +13,38 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser fd int vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree 
atomic.Pointer[bart.Table[routing.Gateways]] l *slog.Logger + + readBuf []byte + batchRet [1]tio.Packet +} + +func (t *tun) Read() ([]tio.Packet, error) { + n, err := t.rwc.Read(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil +} + +func (t *tun) Write(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) Close() error { + return t.rwc.Close() } func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { @@ -32,10 +53,11 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t := &tun{ - ReadWriteCloser: file, - fd: deviceFd, - vpnNetworks: vpnNetworks, - l: l, + rwc: file, + fd: deviceFd, + vpnNetworks: vpnNetworks, + l: l, + readBuf: make([]byte, defaultBatchBufSize), } err := t.reload(c, true) @@ -62,7 +84,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t tun) Activate() error { +func (t *tun) Activate() error { return nil } @@ -99,6 +121,10 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for android") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for android") +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 524ef0cd..9ace4fc8 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -23,7 +24,7 @@ import ( ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser Device string vpnNetworks 
[]netip.Prefix DefaultMTU int @@ -34,6 +35,9 @@ type tun struct { // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte + + readBuf []byte + batchRet [1]tio.Packet } type ifReq struct { @@ -124,11 +128,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t } t := &tun{ - ReadWriteCloser: os.NewFile(uintptr(fd), ""), - Device: name, - vpnNetworks: vpnNetworks, - DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + rwc: os.NewFile(uintptr(fd), ""), + Device: name, + vpnNetworks: vpnNetworks, + DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + readBuf: make([]byte, defaultBatchBufSize), } err = t.reload(c, true) @@ -158,8 +163,8 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, e } func (t *tun) Close() error { - if t.ReadWriteCloser != nil { - return t.ReadWriteCloser.Close() + if t.rwc != nil { + return t.rwc.Close() } return nil } @@ -502,15 +507,24 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { buf := make([]byte, len(to)+4) - n, err := t.ReadWriteCloser.Read(buf) + n, err := t.rwc.Read(buf) copy(to, buf[4:]) return n - 4, err } +func (t *tun) Read() ([]tio.Packet, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil +} + // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { buf := t.out @@ -536,7 +550,7 @@ func (t *tun) Write(from []byte) (int, error) { copy(buf[4:], from) - n, err := t.ReadWriteCloser.Write(buf) + n, err := t.rwc.Write(buf) return n - 4, err } @@ -552,6 +566,10 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") +func (t *tun) 
NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for darwin") +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index f47880dd..ff86bc29 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -10,6 +10,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -18,9 +19,45 @@ type disabledTun struct { vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us - tx metrics.Counter - rx metrics.Counter - l *slog.Logger + tx metrics.Counter + rx metrics.Counter + l *slog.Logger + numReaders int +} + +// disabledQueue is one tio.Queue view onto a shared disabledTun. Each queue +// owns a private batchRet so concurrent Read calls from different reader +// goroutines do not race on the returned slice. +type disabledQueue struct { + parent *disabledTun + batchRet [1]tio.Packet +} + +func (q *disabledQueue) Read() ([]tio.Packet, error) { + r, ok := <-q.parent.read + if !ok { + return nil, io.EOF + } + + q.parent.tx.Inc(1) + if q.parent.l.Enabled(context.Background(), slog.LevelDebug) { + q.parent.l.Debug("Write payload", "raw", prettyPacket(r)) + } + + q.batchRet[0] = tio.Packet{Bytes: r} + return q.batchRet[:], nil +} + +// Write on a queue forwards to the underlying disabledTun. All queues share +// one ICMP-handling/log path so this is a thin pass-through. +func (q *disabledQueue) Write(b []byte) (int, error) { + return q.parent.Write(b) +} + +// Close on a queue is a no-op. The shared channel and metrics are owned by +// the disabledTun; Close on the device tears them down once for everybody. 
+func (q *disabledQueue) Close() error { + return nil } func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { @@ -28,6 +65,7 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), l: l, + numReaders: 1, } if metricsEnabled { @@ -57,24 +95,6 @@ func (*disabledTun) Name() string { return "disabled" } -func (t *disabledTun) Read(b []byte) (int, error) { - r, ok := <-t.read - if !ok { - return 0, io.EOF - } - - if len(r) > len(b) { - return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b)) - } - - t.tx.Inc(1) - if t.l.Enabled(context.Background(), slog.LevelDebug) { - t.l.Debug("Write payload", "raw", prettyPacket(r)) - } - - return copy(b, r), nil -} - func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { out := make([]byte, len(b)) out = iputil.CreateICMPEchoResponse(b, out) @@ -110,8 +130,17 @@ func (t *disabledTun) SupportsMultiqueue() bool { return true } -func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return t, nil +func (t *disabledTun) NewMultiQueueReader() error { + t.numReaders++ + return nil +} + +func (t *disabledTun) Readers() []tio.Queue { + out := make([]tio.Queue, t.numReaders) + for i := range t.numReaders { + out[i] = &disabledQueue{parent: t} + } + return out } func (t *disabledTun) Close() error { diff --git a/overlay/tun_file_linux_test.go b/overlay/tun_file_linux_test.go deleted file mode 100644 index 5ab87e05..00000000 --- a/overlay/tun_file_linux_test.go +++ /dev/null @@ -1,120 +0,0 @@ -//go:build linux && !android && !e2e_testing -// +build linux,!android,!e2e_testing - -package overlay - -import ( - "errors" - "os" - "sync" - "testing" - "time" - - "golang.org/x/sys/unix" -) - -// newReadPipe returns a read fd. The matching write fd is registered for cleanup. -// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). 
-func newReadPipe(t *testing.T) int { - t.Helper() - var fds [2]int - if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { - t.Fatalf("pipe2: %v", err) - } - t.Cleanup(func() { _ = unix.Close(fds[1]) }) - return fds[0] -} - -func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { - tf, err := newTunFd(newReadPipe(t)) - if err != nil { - t.Fatalf("newTunFd: %v", err) - } - t.Cleanup(func() { _ = tf.Close() }) - - done := make(chan error, 1) - go func() { - _, err := tf.Read(make([]byte, 64)) - done <- err - }() - - // Verify Read is actually blocked in poll. - select { - case err := <-done: - t.Fatalf("Read returned before shutdown signal: %v", err) - case <-time.After(50 * time.Millisecond): - } - - if err := tf.wakeForShutdown(); err != nil { - t.Fatalf("wakeForShutdown: %v", err) - } - - select { - case err := <-done: - if !errors.Is(err, os.ErrClosed) { - t.Fatalf("expected os.ErrClosed, got %v", err) - } - case <-time.After(2 * time.Second): - t.Fatal("Read did not wake on shutdown") - } -} - -func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) { - parent, err := newTunFd(newReadPipe(t)) - if err != nil { - t.Fatalf("newTunFd: %v", err) - } - friend, err := parent.newFriend(newReadPipe(t)) - if err != nil { - _ = parent.Close() - t.Fatalf("newFriend: %v", err) - } - t.Cleanup(func() { - _ = friend.Close() - _ = parent.Close() - }) - - readers := []*tunFile{parent, friend} - errs := make([]error, len(readers)) - var wg sync.WaitGroup - for i, r := range readers { - wg.Add(1) - go func(i int, r *tunFile) { - defer wg.Done() - _, errs[i] = r.Read(make([]byte, 64)) - }(i, r) - } - - time.Sleep(50 * time.Millisecond) - - if err := parent.wakeForShutdown(); err != nil { - t.Fatalf("wakeForShutdown: %v", err) - } - - done := make(chan struct{}) - go func() { wg.Wait(); close(done) }() - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("readers did not wake") - } - - for i, err := range errs { - if !errors.Is(err, 
os.ErrClosed) { - t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) - } - } -} - -func TestTunFile_Close_Idempotent(t *testing.T) { - tf, err := newTunFd(newReadPipe(t)) - if err != nil { - t.Fatalf("newTunFd: %v", err) - } - if err := tf.Close(); err != nil { - t.Fatalf("first Close: %v", err) - } - if err := tf.Close(); err != nil { - t.Fatalf("second Close should be a no-op, got %v", err) - } -} diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 3d995553..71784ad7 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -7,7 +7,6 @@ import ( "bytes" "errors" "fmt" - "io" "io/fs" "log/slog" "net/netip" @@ -20,7 +19,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -103,6 +102,9 @@ type tun struct { readPoll [2]unix.PollFd writePoll [2]unix.PollFd closed atomic.Bool + + readBuf []byte + batchRet [1]tio.Packet } // blockOnRead waits until the tun fd is readable or shutdown has been signaled. 
@@ -157,7 +159,16 @@ func (t *tun) blockOnWrite() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) Read() ([]tio.Packet, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil +} + +func (t *tun) readOne(to []byte) (int, error) { // first 4 bytes is protocol family, in network byte order var head [4]byte iovecs := [2]syscall.Iovec{ @@ -375,6 +386,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, fd: fd, + readBuf: make([]byte, defaultBatchBufSize), shutdownR: shutdownR, shutdownW: shutdownW, readPoll: [2]unix.PollFd{ @@ -565,8 +577,8 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for freebsd") } func (t *tun) addRoutes(logErrors bool) error { @@ -593,6 +605,10 @@ func (t *tun) addRoutes(logErrors bool) error { return nil } +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} +} + func (t *tun) removeRoutes(routes []Route) error { for _, r := range routes { if !r.Install { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 6bfcbdfb..2c332e06 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -16,16 +16,37 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *slog.Logger + + readBuf []byte + batchRet [1]tio.Packet +} + +func (t *tun) Read() ([]tio.Packet, error) { + n, 
err := t.rwc.Read(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil +} + +func (t *tun) Write(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) Close() error { + return t.rwc.Close() } func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { @@ -35,9 +56,10 @@ func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - vpnNetworks: vpnNetworks, - ReadWriteCloser: &tunReadCloser{f: file}, - l: l, + vpnNetworks: vpnNetworks, + rwc: &tunReadCloser{f: file}, + l: l, + readBuf: make([]byte, defaultBatchBufSize), } err := t.reload(c, true) @@ -155,6 +177,10 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for ios") +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index c6cfb686..3cf2d70c 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,9 +4,7 @@ package overlay import ( - "encoding/binary" "fmt" - "io" "log/slog" "net" "net/netip" @@ -19,180 +17,15 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) -// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. -// A shared eventfd allows Close to wake all readers blocked in poll. 
-type tunFile struct { - fd int - shutdownFd int - lastOne bool - readPoll [2]unix.PollFd - writePoll [2]unix.PollFd - closed bool -} - -// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun -func (r *tunFile) newFriend(fd int) (*tunFile, error) { - if err := unix.SetNonblock(fd, true); err != nil { - return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) - } - return &tunFile{ - fd: fd, - shutdownFd: r.shutdownFd, - readPoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLIN}, - {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, - }, - writePoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLOUT}, - {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, - }, - }, nil -} - -func newTunFd(fd int) (*tunFile, error) { - if err := unix.SetNonblock(fd, true); err != nil { - return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) - } - - shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) - if err != nil { - return nil, fmt.Errorf("failed to create eventfd: %w", err) - } - - out := &tunFile{ - fd: fd, - shutdownFd: shutdownFd, - lastOne: true, - readPoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLIN}, - {Fd: int32(shutdownFd), Events: unix.POLLIN}, - }, - writePoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLOUT}, - {Fd: int32(shutdownFd), Events: unix.POLLIN}, - }, - } - - return out, nil -} - -func (r *tunFile) blockOnRead() error { - const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR - var err error - for { - _, err = unix.Poll(r.readPoll[:], -1) - if err != unix.EINTR { - break - } - } - //always reset these! 
- tunEvents := r.readPoll[0].Revents - shutdownEvents := r.readPoll[1].Revents - r.readPoll[0].Revents = 0 - r.readPoll[1].Revents = 0 - //do the err check before trusting the potentially bogus bits we just got - if err != nil { - return err - } - if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { - return os.ErrClosed - } else if tunEvents&problemFlags != 0 { - return os.ErrClosed - } - return nil -} - -func (r *tunFile) blockOnWrite() error { - const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR - var err error - for { - _, err = unix.Poll(r.writePoll[:], -1) - if err != unix.EINTR { - break - } - } - //always reset these! - tunEvents := r.writePoll[0].Revents - shutdownEvents := r.writePoll[1].Revents - r.writePoll[0].Revents = 0 - r.writePoll[1].Revents = 0 - //do the err check before trusting the potentially bogus bits we just got - if err != nil { - return err - } - if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { - return os.ErrClosed - } else if tunEvents&problemFlags != 0 { - return os.ErrClosed - } - return nil -} - -func (r *tunFile) Read(buf []byte) (int, error) { - for { - if n, err := unix.Read(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { - if err = r.blockOnRead(); err != nil { - return 0, err - } - continue - } else if err == unix.EINTR { - continue - } else if err == unix.EBADF { - return 0, os.ErrClosed - } else { - return 0, err - } - } -} - -func (r *tunFile) Write(buf []byte) (int, error) { - for { - if n, err := unix.Write(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { - if err = r.blockOnWrite(); err != nil { - return 0, err - } - continue - } else if err == unix.EINTR { - continue - } else if err == unix.EBADF { - return 0, os.ErrClosed - } else { - return 0, err - } - } -} - -func (r *tunFile) wakeForShutdown() error { - var buf [8]byte - binary.NativeEndian.PutUint64(buf[:], 1) - _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) - return err -} - -func (r *tunFile) 
Close() error { - if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem - return nil - } - r.closed = true - if r.lastOne { - _ = unix.Close(r.shutdownFd) - } - return unix.Close(r.fd) -} - type tun struct { - *tunFile - readers []*tunFile + readers tio.QueueSet closeLock sync.Mutex Device string vpnNetworks []netip.Prefix @@ -239,7 +72,9 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) + // We don't know what flags the caller opened this fd with and can't turn + // on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds. + t, err := newTunGeneric(c, l, deviceFd, false, false, vpnNetworks) if err != nil { return nil, err } @@ -249,46 +84,65 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip return t, nil } -func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +// openTunDev opens /dev/net/tun, creating the device node first if it's +// missing (docker containers occasionally omit it). 
+func openTunDev() (int, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - // If /dev/net/tun doesn't exist, try to create it (will happen in docker) - if os.IsNotExist(err) { - err = os.MkdirAll("/dev/net", 0755) - if err != nil { - return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) - } - err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) - if err != nil { - return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) - } - - fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) - } - } else { - return nil, err - } + if err == nil { + return fd, nil } + if !os.IsNotExist(err) { + return -1, err + } + if err = os.MkdirAll("/dev/net", 0755); err != nil { + return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil { + return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err) + } + fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) + if err != nil { + return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) + } + return fd, nil +} +// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen +// device name on success. 
+func tunSetIff(fd int, name string, flags uint16) (string, error) { var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + req.Flags = flags + copy(req.Name[:], name) + if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + return "", err + } + return strings.Trim(string(req.Name[:]), "\x00"), nil +} + +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { + baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI) if multiqueue { - req.Flags |= unix.IFF_MULTI_QUEUE + baseFlags |= unix.IFF_MULTI_QUEUE } nameStr := c.GetString("tun.dev", "") - copy(req.Name[:], nameStr) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { - _ = unix.Close(fd) - return nil, &NameError{ - Name: nameStr, - Underlying: err, - } - } - name := strings.Trim(string(req.Name[:]), "\x00") - t, err := newTunGeneric(c, l, fd, vpnNetworks) + // First try to enable IFF_VNET_HDR via TUNSETIFF and negotiate TUN_F_* + // offloads via TUNSETOFFLOAD so we can receive TSO/USO superpackets. + // We try TSO+USO first, fall back to TSO-only on kernels without USO + // (Linux < 6.2), and finally give up on virtio headers entirely and + // reopen as a plain TUN if neither offload mask is accepted. + fd, err := openTunDev() + if err != nil { + return nil, err + } + + name, err := tunSetIff(fd, nameStr, baseFlags) + if err != nil { + _ = unix.Close(fd) + return nil, &NameError{Name: nameStr, Underlying: err} + } + + t, err := newTunGeneric(c, l, fd, false, false, vpnNetworks) if err != nil { return nil, err } @@ -299,15 +153,21 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue } // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. 
-func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { - tfd, err := newTunFd(fd) +func newTunGeneric(c *config.C, l *slog.Logger, fd int, vnetHdr, usoEnabled bool, vpnNetworks []netip.Prefix) (*tun, error) { + qs, err := tio.NewPollQueueSet() + if err != nil { _ = unix.Close(fd) return nil, err } + err = qs.Add(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } + t := &tun{ - tunFile: tfd, - readers: []*tunFile{tfd}, + readers: qs, closeLock: sync.Mutex{}, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), @@ -410,32 +270,29 @@ func (t *tun) SupportsMultiqueue() bool { return true } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() error { t.closeLock.Lock() defer t.closeLock.Unlock() fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { - return nil, err + return err } - var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) - copy(req.Name[:], t.Device) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + + if _, err = tunSetIff(fd, t.Device, flags); err != nil { _ = unix.Close(fd) - return nil, err + return err } - out, err := t.tunFile.newFriend(fd) + err = t.readers.Add(fd) if err != nil { _ = unix.Close(fd) - return nil, err + return err } - t.readers = append(t.readers, out) - - return out, nil + return nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -603,6 +460,15 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } + // Match the metric the kernel uses for its auto-installed connected + // route, so RouteReplace overwrites it in place instead of adding a + // second route at a worse metric. IPv6 connected routes are installed + // at metric 256 (IP6_RT_PRIO_KERN); IPv4 uses 0. 
Without this, the + // kernel route wins lookups and our MTU / AdvMSS / Features never + // apply on v6. + if cidr.Addr().Is6() { + nr.Priority = 256 + } err := netlink.RouteReplace(&nr) if err != nil { t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) @@ -869,6 +735,10 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { t.routeTree.Store(newTree) } +func (t *tun) Readers() []tio.Queue { + return t.readers.Queues() +} + func (t *tun) Close() error { t.closeLock.Lock() defer t.closeLock.Unlock() @@ -878,32 +748,10 @@ func (t *tun) Close() error { t.routeChan = nil } - // Signal all readers blocked in poll to wake up and exit - _ = t.tunFile.wakeForShutdown() - if t.ioctlFd > 0 { _ = unix.Close(int(t.ioctlFd)) t.ioctlFd = 0 } - for i := range t.readers { - if i == 0 { - continue //we want to close the zeroth reader last - } - err := t.readers[i].Close() - if err != nil { - t.l.Error("error closing tun reader", "reader", i, "error", err) - } else { - t.l.Info("closed tun reader", "reader", i) - } - } - - //this is t.readers[0] too - err := t.tunFile.Close() - if err != nil { - t.l.Error("error closing tun reader", "reader", 0, "error", err) - } else { - t.l.Info("closed tun reader", "reader", 0) - } - return err + return t.readers.Close() } diff --git a/overlay/tun_linux_test.go b/overlay/tun_linux_test.go index 1c1842da..1003a165 100644 --- a/overlay/tun_linux_test.go +++ b/overlay/tun_linux_test.go @@ -3,7 +3,9 @@ package overlay -import "testing" +import ( + "testing" +) var runAdvMSSTests = []struct { name string diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index c971bb6e..e8678959 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "log/slog" "net/netip" "os" @@ -17,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" 
"github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -66,6 +66,22 @@ type tun struct { l *slog.Logger f *os.File fd int + + readBuf []byte + batchRet [1]tio.Packet +} + +func (t *tun) Read() ([]tio.Packet, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -102,6 +118,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, + readBuf: make([]byte, defaultBatchBufSize), } err = t.reload(c, true) @@ -141,7 +158,7 @@ func (t *tun) Close() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { rc, err := t.f.SyscallConn() if err != nil { return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err) @@ -394,8 +411,8 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for netbsd") } func (t *tun) addRoutes(logErrors bool) error { diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 81362184..0e754732 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "log/slog" "net/netip" "os" @@ -17,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -59,6 +59,18 @@ type tun struct { fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte 
+ + readBuf []byte + batchRet [1]tio.Packet +} + +func (t *tun) Read() ([]tio.Packet, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -95,6 +107,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, + readBuf: make([]byte, defaultBatchBufSize), } err = t.reload(c, true) @@ -124,7 +137,7 @@ func (t *tun) Close() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { buf := make([]byte, len(to)+4) n, err := t.f.Read(buf) @@ -314,8 +327,8 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for openbsd") } func (t *tun) addRoutes(logErrors bool) error { @@ -366,6 +379,10 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} +} + func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 8acd83f0..898adc23 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -14,6 +14,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/udp" ) @@ -28,6 +29,8 @@ type TestTun struct { closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula + + batchRet [1]tio.Packet } func 
newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { @@ -48,6 +51,9 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*T l: l, rxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10), + batchRet: [1]tio.Packet{ + tio.Packet{Bytes: make([]byte, udp.MTU)}, + }, }, nil } @@ -162,7 +168,17 @@ func (t *TestTun) Close() error { return nil } -func (t *TestTun) Read(b []byte) (int, error) { +func (t *TestTun) Read() ([]tio.Packet, error) { + t.batchRet[0].Bytes = t.batchRet[0].Bytes[:udp.MTU] + n, err := t.read(t.batchRet[0].Bytes) + if err != nil { + return nil, err + } + t.batchRet[0].Bytes = t.batchRet[0].Bytes[:n] + return t.batchRet[:], nil +} + +func (t *TestTun) read(b []byte) (int, error) { p, ok := <-t.rxPackets if !ok { return 0, os.ErrClosed @@ -177,10 +193,14 @@ func (t *TestTun) Read(b []byte) (int, error) { return n, nil } +func (t *TestTun) Readers() []tio.Queue { + return []tio.Queue{t} +} + func (t *TestTun) SupportsMultiqueue() bool { return false } -func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented") +func (t *TestTun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index cf01615f..a5ee063c 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -6,7 +6,6 @@ package overlay import ( "crypto" "fmt" - "io" "log/slog" "net/netip" "os" @@ -18,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" @@ -45,6 +45,18 @@ type winTun struct { l *slog.Logger tun *wintun.NativeTun + + readBuf []byte + batchRet [1]tio.Packet +} + +func (t *winTun) Read() ([]tio.Packet, error) { + n, err := t.tun.Read(t.readBuf, 0) + if err != nil 
{ + return nil, err + } + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} + return t.batchRet[:], nil } func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) { @@ -69,6 +81,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w } t := &winTun{ + readBuf: make([]byte, defaultBatchBufSize), Device: deviceName, vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), @@ -255,10 +268,6 @@ func (t *winTun) Name() string { return t.Device } -func (t *winTun) Read(b []byte) (int, error) { - return t.tun.Read(b, 0) -} - func (t *winTun) Write(b []byte) (int, error) { return t.tun.Write(b, 0) } @@ -267,8 +276,12 @@ func (t *winTun) SupportsMultiqueue() bool { return false } -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") +func (t *winTun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for windows") +} + +func (t *winTun) Readers() []tio.Queue { + return []tio.Queue{t} } func (t *winTun) Close() error { diff --git a/overlay/user.go b/overlay/user.go index e5f27f37..3128ebf0 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -1,11 +1,13 @@ package overlay import ( + "errors" "io" "log/slog" "net/netip" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -28,12 +30,28 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { type UserDevice struct { vpnNetworks []netip.Prefix + numReaders int outboundReader *io.PipeReader outboundWriter *io.PipeWriter inboundReader *io.PipeReader inboundWriter *io.PipeWriter + + readBuf []byte + batchRet [1]tio.Packet +} + +func (d *UserDevice) Read() ([]tio.Packet, error) { + if d.readBuf == nil { + d.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := d.outboundReader.Read(d.readBuf) + if err != nil { + return nil, err + } + d.batchRet[0] = 
tio.Packet{Bytes: d.readBuf[:n]} + return d.batchRet[:], nil } func (d *UserDevice) Activate() error { @@ -47,23 +65,25 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { } func (d *UserDevice) SupportsMultiqueue() bool { - return true + return false } -func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return d, nil +func (d *UserDevice) NewMultiQueueReader() error { + return errors.New("not implemented") +} + +func (d *UserDevice) Readers() []tio.Queue { + return []tio.Queue{d} } func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { return d.inboundReader, d.outboundWriter } -func (d *UserDevice) Read(p []byte) (n int, err error) { - return d.outboundReader.Read(p) -} func (d *UserDevice) Write(p []byte) (n int, err error) { return d.inboundWriter.Write(p) } + func (d *UserDevice) Close() error { d.inboundWriter.Close() d.outboundWriter.Close() diff --git a/udp/conn.go b/udp/conn.go index 30d89dec..37277054 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -8,16 +8,49 @@ import ( const MTU = 9001 +// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is +// required to accept. Callers SHOULD NOT pass more than this per call; Linux +// backends preallocate sendmmsg scratch sized to this value, so exceeding it +// only costs additional sendmmsg chunks within a single WriteBatch call. +const MaxWriteBatch = 128 + +// RxMeta carries per-packet metadata extracted from the RX path (ancillary +// data, kernel offload state, etc.) and passed to EncReader callbacks. +// Backends that do not produce a particular signal leave its zero value. +// +// OuterECN is the 2-bit IP-level ECN codepoint stamped on the carrier +// datagram (extracted from IP_TOS / IPV6_TCLASS cmsg on Linux). Zero +// means Not-ECT, which is also the value backends without ECN RX support +// supply on every packet. 
+type RxMeta struct { + OuterECN byte +} + type EncReader func( addr netip.AddrPort, payload []byte, + meta RxMeta, ) type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) error + // ListenOut invokes r for each received packet. On batch-capable + // backends (recvmmsg), flush is called after each batch is fully + // delivered — callers use it to flush per-batch accumulators such as + // TUN write coalescers. Single-packet backends call flush after each + // packet. flush must not be nil. + ListenOut(r EncReader, flush func()) error WriteTo(b []byte, addr netip.AddrPort) error + // WriteBatch sends a contiguous batch of packets, each with its own + // destination. bufs and addrs must have the same length. outerECNs may + // be nil (treated as all-zero / Not-ECT); when non-nil it must have the + // same length as bufs, and outerECNs[i] is the 2-bit IP-level ECN + // codepoint to set on packet i's outer header. Linux uses sendmmsg(2) + // for a single syscall and attaches the value as IP_TOS / IPV6_TCLASS + // cmsg; other backends ignore it. Returns on the first error; callers + // may observe a partial send if some packets went out before the error. 
+ WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool Close() error @@ -31,7 +64,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) error { +func (NoopConn) ListenOut(_ EncReader, _ func()) error { return nil } func (NoopConn) SupportsMultipleReaders() bool { @@ -40,6 +73,9 @@ func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort, _ []byte) error { + return nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 8a4f5b18..e6ecea8f 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -165,7 +174,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) for { @@ -179,7 +188,8 @@ func (u *StdConn) ListenOut(r EncReader) error { u.l.Error("unexpected udp socket receive error", "error", err) } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{}) + flush() } } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 131eb73b..0c254906 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -44,6 +44,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { 
return err } +func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -73,7 +82,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) error { +func (u *GenericConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -93,7 +102,8 @@ func (u *GenericConn) ListenOut(r EncReader) error { continue } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{}) + flush() } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 3e2d726a..6465be32 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -24,6 +24,22 @@ type StdConn struct { isV4 bool l *slog.Logger batch int + + // sendmmsg scratch. Each queue has its own StdConn, so no locking is + // needed. Sized to MaxWriteBatch at construction; WriteBatch chunks + // larger inputs. + writeMsgs []rawMessage + writeIovs []iovec + writeNames [][]byte + + // sendmmsg(2) callback state. sendmmsgCB is bound once in NewListener + // to the sendmmsgRun method value so passing it to rawConn.Write does + // not allocate a fresh closure per send; sendmmsgN/Sent/Errno carry + // the inputs and outputs across the call without escaping locals. 
+ sendmmsgCB func(fd uintptr) bool + sendmmsgN int + sendmmsgSent int + sendmmsgErrno syscall.Errno } func setReusePort(network, address string, c syscall.RawConn) error { @@ -70,9 +86,23 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) } out.isV4 = af == unix.AF_INET + out.prepareWriteMessages(MaxWriteBatch) + out.sendmmsgCB = out.sendmmsgRun + return out, nil } +func (u *StdConn) prepareWriteMessages(n int) { + u.writeMsgs = make([]rawMessage, n) + u.writeIovs = make([]iovec, n) + u.writeNames = make([][]byte, n) + + for i := range u.writeMsgs { + u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0] + } +} + func (u *StdConn) SupportsMultipleReaders() bool { return true } @@ -171,7 +201,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) error { +func (u *StdConn) listenOutSingle(r EncReader, flush func()) error { var err error var n int var from netip.AddrPort @@ -183,16 +213,33 @@ func (u *StdConn) listenOutSingle(r EncReader) error { return err } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) - r(from, buffer[:n]) + // listenOutSingle uses ReadFromUDPAddrPort which discards cmsgs, + // so the outer ECN field is not visible on this path. Zero RxMeta + // (Not-ECT) means RFC 6040 combine is a no-op. 
+ r(from, buffer[:n], RxMeta{}) + flush() } } -func (u *StdConn) listenOutBatch(r EncReader) error { +// readSockaddr decodes the source address out of a recvmmsg name buffer +func (u *StdConn) readSockaddr(name []byte) netip.AddrPort { var ip netip.Addr + // It's ok to skip the ok check here, the slicing is the only error that can occur and it will panic + if u.isV4 { + ip, _ = netip.AddrFromSlice(name[4:8]) + } else { + ip, _ = netip.AddrFromSlice(name[8:24]) + } + return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(name[2:4])) +} + +func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { var n int var operr error - msgs, buffers, names := u.PrepareRawMessages(u.batch) + bufSize := MTU + cmsgSpace := 0 + msgs, buffers, names, _ := u.PrepareRawMessages(u.batch, bufSize, cmsgSpace) //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read //defining it outside the loop so it gets re-used @@ -211,22 +258,18 @@ func (u *StdConn) listenOutBatch(r EncReader) error { } for i := 0; i < n; i++ { - // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic - if u.isV4 { - ip, _ = netip.AddrFromSlice(names[i][4:8]) - } else { - ip, _ = netip.AddrFromSlice(names[i][8:24]) - } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + r(u.readSockaddr(names[i]), buffers[i][:msgs[i].Len], RxMeta{}) } + + flush() } } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { if u.batch == 1 { - return u.listenOutSingle(r) + return u.listenOutSingle(r, flush) } else { - return u.listenOutBatch(r) + return u.listenOutBatch(r, flush) } } @@ -235,6 +278,120 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return err } +// WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on +// StdConn. 
If supported, consecutive packets to the same destination with +// matching segment sizes (all but possibly the last) are coalesced into a +// single mmsghdr entry +// +// If sendmmsg returns an error and zero entries went out, we fall back to +// per-packet WriteTo for that chunk so the caller still gets best-effort +// delivery. On a partial send we resume at the first un-acked entry on +// the next iteration. +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i := 0; i < len(bufs); { + chunk := min(len(bufs)-i, len(u.writeMsgs)) + + for k := 0; k < chunk; k++ { + u.writeIovs[k].Base = &bufs[i+k][0] + setIovLen(&u.writeIovs[k], len(bufs[i+k])) + + nlen, err := writeSockaddr(u.writeNames[k], addrs[i+k], u.isV4) + if err != nil { + return err + } + + hdr := &u.writeMsgs[k].Hdr + hdr.Iov = &u.writeIovs[k] + setMsgIovlen(hdr, 1) + hdr.Namelen = uint32(nlen) + } + + sent, serr := u.sendmmsg(chunk) + if serr != nil && sent <= 0 { + // sendmmsg returns -1 / sent=0 when entry 0 itself failed; log + // that entry's destination and fall back to per-packet WriteTo + // for the whole chunk so the caller still gets best-effort + // delivery without duplicating packets the kernel accepted. + u.l.Warn("sendmmsg failed, falling back to per-packet WriteTo", + "err", serr, + "entries", chunk, + "entry0_dst", addrs[i], + "isV4", u.isV4, + ) + for k := 0; k < chunk; k++ { + if werr := u.WriteTo(bufs[i+k], addrs[i+k]); werr != nil { + return werr + } + } + i += chunk + continue + } + i += sent + } + return nil +} + +// sendmmsg issues sendmmsg(2) against the first n entries of u.writeMsgs. +// The bound u.sendmmsgCB is passed to rawConn.Write so no closure is +// allocated per call; inputs and outputs ride on the StdConn fields. 
+func (u *StdConn) sendmmsg(n int) (int, error) { + u.sendmmsgN = n + u.sendmmsgSent = 0 + u.sendmmsgErrno = 0 + if err := u.rawConn.Write(u.sendmmsgCB); err != nil { + return u.sendmmsgSent, err + } + if u.sendmmsgErrno != 0 { + return u.sendmmsgSent, &net.OpError{Op: "sendmmsg", Err: u.sendmmsgErrno} + } + return u.sendmmsgSent, nil +} + +// sendmmsgRun is the rawConn.Write callback. It is bound once into +// u.sendmmsgCB at construction so it stays alloc-free in the hot path; +// inputs (sendmmsgN) and outputs (sendmmsgSent, sendmmsgErrno) ride on +// the receiver rather than escaping locals. +func (u *StdConn) sendmmsgRun(fd uintptr) bool { + r1, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, fd, + uintptr(unsafe.Pointer(&u.writeMsgs[0])), uintptr(u.sendmmsgN), + 0, 0, 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + return false + } + u.sendmmsgSent = int(r1) + u.sendmmsgErrno = errno + return true +} + +// writeSockaddr encodes addr into buf (which must be at least +// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is +// true and addr is not a v4 (or v4-in-v6) address, returns an error. +func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) { + ap := addr.Addr().Unmap() + if isV4 { + if !ap.Is4() { + return 0, ErrInvalidIPv6RemoteForSocket + } + // struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) } + // sa_family is host endian. 
+ binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + ip4 := ap.As4() + copy(buf[4:8], ip4[:]) + clear(buf[8:16]) + return unix.SizeofSockaddrInet4, nil + } + // struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) } + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + binary.NativeEndian.PutUint32(buf[4:8], 0) + ip6 := addr.Addr().As16() + copy(buf[8:24], ip6[:]) + binary.NativeEndian.PutUint32(buf[24:28], 0) + return unix.SizeofSockaddrInet6, nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..0f153a49 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -30,13 +30,18 @@ type rawMessage struct { Len uint32 } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var cmsgs []byte + if cmsgSpace > 0 { + cmsgs = make([]byte, n*cmsgSpace) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ @@ -48,7 +53,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if cmsgSpace > 0 { + msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace] + msgs[i].Hdr.Controllen = uint32(cmsgSpace) + } } - return msgs, buffers, names + return msgs, buffers, names, cmsgs +} + +func setIovLen(v *iovec, n int) { + v.Len = uint32(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint32(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint32(n) +} + +func 
setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint32(n) } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..dc373538 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -33,13 +33,18 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var cmsgs []byte + if cmsgSpace > 0 { + cmsgs = make([]byte, n*cmsgSpace) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ @@ -51,7 +56,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if cmsgSpace > 0 { + msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace] + msgs[i].Hdr.Controllen = uint64(cmsgSpace) + } } - return msgs, buffers, names + return msgs, buffers, names, cmsgs +} + +func setIovLen(v *iovec, n int) { + v.Len = uint64(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint64(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint64(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint64(n) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index d110af19..a95ad3d0 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) error { +func (u *RIOConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -161,7 +161,8 @@ func (u *RIOConn) ListenOut(r EncReader) error { continue } - r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), 
(rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], RxMeta{}) + flush() } } @@ -316,6 +317,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } +func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index f872e32a..6b877b71 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -157,15 +157,24 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } } +func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} -func (u *TesterConn) ListenOut(r EncReader) error { +func (u *TesterConn) ListenOut(r EncReader, flush func()) error { for { select { case <-u.done: return os.ErrClosed case p := <-u.RxPackets: - r(p.From, p.Data) + r(p.From, p.Data, RxMeta{}) p.Release() + flush() } } }