batched tun interface

This commit is contained in:
JackDoan
2026-04-17 10:25:05 -05:00
parent 9a30c5b6a1
commit afcdf2163b
20 changed files with 939 additions and 68 deletions

View File

@@ -974,6 +974,7 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for _, cp := range hh.packetStore { for _, cp := range hh.packetStore {
//todo use a sendbatcher
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
} }
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))

194
inside.go
View File

@@ -2,6 +2,7 @@ package nebula
import ( import (
"context" "context"
"io"
"log/slog" "log/slog"
"net/netip" "net/netip"
@@ -9,10 +10,16 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/overlay/batch"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) {
// borrowed: pkt.Bytes is owned by the originating tio.Queue and is
// only valid until the next Read on that queue. If you must keep
// the packet, use pkt.Clone() to detach it
packet := pkt.Bytes
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
@@ -37,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
// routes packets from the Nebula addr to the Nebula addr through the Nebula // routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device. // TUN device.
if immediatelyForwardToSelf { if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet) err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
_, werr := f.readers[q].Write(seg)
return werr
})
if err != nil { if err != nil {
f.l.Error("Failed to forward to tun", "error", err) f.l.Error("Failed to forward to tun", "error", err)
} }
@@ -53,11 +63,23 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) // borrowed: SegmentSuperpacket builds each segment in the kernel-supplied pkt
// bytes underneath. cachePacket explicitly copies its argument (handshake_manager.go cachePacket),
// so retaining segments past the loop is safe.
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
hh.cachePacket(f.l, header.Message, 0, seg, f.sendMessageNow, f.cachedPacketMetrics)
return nil
})
if err != nil && f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Failed to segment superpacket for handshake cache",
"error", err,
"vpnAddr", fwPacket.RemoteAddr,
)
}
}) })
if hostinfo == nil { if hostinfo == nil {
f.rejectInside(packet, out, q) f.rejectInside(packet, rejectBuf, q)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
"vpnAddr", fwPacket.RemoteAddr, "vpnAddr", fwPacket.RemoteAddr,
@@ -73,10 +95,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) f.sendInsideMessage(hostinfo, pkt, nb, sendBatch)
} else { } else {
f.rejectInside(packet, out, q) f.rejectInside(packet, rejectBuf, q)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping outbound packet", hostinfo.logger(f.l).Debug("dropping outbound packet",
"fwPacket", fwPacket, "fwPacket", fwPacket,
@@ -86,6 +107,124 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
} }
// sendInsideEncrypt stamps a nebula Message header into scratch, bumps the
// connection's message counter, and AEAD-encrypts seg in place after that
// header. It returns the header+ciphertext slice ready to send, or nil when
// encryption failed (the failure is logged here; callers skip the segment).
//
// scratch must have room for header.Len + len(seg) + the AEAD tag; nb is the
// caller-reused nonce buffer.
func (f *Interface) sendInsideEncrypt(hostinfo *HostInfo, ci *ConnectionState, seg, scratch, nb []byte) []byte {
	if noiseutil.EncryptLockNeeded {
		// NOTE(review): presumably required for goboring AESGCMTLS nonce
		// checks, mirroring the locking in SendVia/prepareSendVia — confirm.
		ci.writeLock.Lock()
	}
	// Counter is consumed even if encryption fails below; a gap in the
	// counter sequence is harmless to the receiver's replay window.
	c := ci.messageCounter.Add(1)
	out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
	f.connectionManager.Out(hostinfo)
	// Encrypt in place: ciphertext is appended to out, with out's existing
	// header bytes used as associated data.
	out, encErr := ci.eKey.EncryptDanger(out, out, seg, c, nb)
	if noiseutil.EncryptLockNeeded {
		ci.writeLock.Unlock()
	}
	if encErr != nil {
		hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
			"error", encErr,
			"udpAddr", hostinfo.remote,
			"counter", c,
		)
		// Skip this segment; the rest of the superpacket can still
		// go out — TCP will retransmit anything we drop here.
		return nil
	}
	return out
}
// sendInsideMessage encrypts a firewall-approved inside packet (or every
// segment of a TSO/USO superpacket) into the caller's batch slot for
// later sendmmsg flush. Segmentation is fused with encryption here so the
// kernel-supplied superpacket bytes never get written into a separate
// scratch arena: SegmentSuperpacket builds each segment's plaintext in
// segScratch[:segLen] in turn, and we encrypt directly into a fresh
// SendBatch slot.
func (f *Interface) sendInsideMessage(hostinfo *HostInfo, pkt tio.Packet, nb []byte, sendBatch batch.TxBatcher) {
	ci := hostinfo.ConnectionState
	if ci.eKey == nil {
		// Tunnel has no encryption key yet; drop silently.
		return
	}
	if hostinfo.lastRebindCount != f.rebindCount {
		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
		hostinfo.lastRebindCount = f.rebindCount
		if f.l.Enabled(context.Background(), slog.LevelDebug) {
			hostinfo.logger(f.l).Debug("Lighthouse update triggered for punch due to rebind counter",
				"vpnAddrs", hostinfo.vpnAddrs,
			)
		}
	}
	if !hostinfo.remote.IsValid() { //the relay path
		//first, find our relay hostinfo:
		var relayHostInfo *HostInfo
		var relay *Relay
		var err error
		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
			relayHostInfo, relay, err = f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
			if err != nil {
				// Stale relay entry: prune it and try the next candidate.
				hostinfo.relayState.DeleteRelay(relayIP)
				hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
					"relay", relayIP,
					"error", err,
				)
				continue
			}
			break
		}
		if relayHostInfo == nil || relay == nil {
			//failure already logged
			return
		}
		err = tio.SegmentSuperpacket(pkt, func(seg []byte) error {
			//relay header + header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305) + relay tag
			scratch := sendBatch.Reserve(header.Len + header.Len + len(seg) + 16 + 16)
			// Encrypt the inner message at offset header.Len, leaving
			// scratch[:header.Len] free — NOTE(review): prepareSendVia with
			// nocopy=true is assumed to write the relay header there in
			// place; confirm against its implementation.
			innerPacket := f.sendInsideEncrypt(hostinfo, ci, seg, scratch[header.Len:], nb)
			if innerPacket == nil {
				// Encrypt failure already logged; skip only this segment.
				return nil
			}
			//now we need to do a relay-encrypt:
			toSend, err := f.prepareSendVia(relayHostInfo, relay, innerPacket, nb, scratch, true)
			if err != nil {
				//already logged
				return nil
			}
			// outerECN 0 == Not-ECT on the carrier header.
			sendBatch.Commit(toSend, relayHostInfo.remote, 0)
			return nil
		})
		if err != nil {
			hostinfo.logger(f.l).Error("Failed to segment superpacket for relay send", "error", err)
		}
		return
	}
	err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
		// header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305)
		scratch := sendBatch.Reserve(header.Len + len(seg) + 16)
		out := f.sendInsideEncrypt(hostinfo, ci, seg, scratch, nb)
		if out == nil {
			// Encrypt failure already logged; skip only this segment.
			return nil
		}
		sendBatch.Commit(out, hostinfo.remote, 0)
		return nil
	})
	if err != nil {
		hostinfo.logger(f.l).Error("Failed to segment superpacket for send",
			"error", err,
		)
	}
}
func (f *Interface) rejectInside(packet []byte, out []byte, q int) { func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
if !f.firewall.InSendReject { if !f.firewall.InSendReject {
return return
@@ -275,21 +414,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
} }
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done func (f *Interface) prepareSendVia(via *HostInfo,
// to the payload for the ultimate target host, making this a useful method for sending
// handshake messages to peers through relay tunnels.
// via is the HostInfo through which the message is relayed.
// ad is the plaintext data to authenticate, but not encrypt
// nb is a buffer used to store the nonce value, re-used for performance reasons.
// out is a buffer used to store the result of the Encrypt operation
// q indicates which writer to use to send the packet.
func (f *Interface) SendVia(via *HostInfo,
relay *Relay, relay *Relay,
ad, ad,
nb, nb,
out []byte, out []byte,
nocopy bool, nocopy bool,
) { ) ([]byte, error) {
if noiseutil.EncryptLockNeeded { if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
via.ConnectionState.writeLock.Lock() via.ConnectionState.writeLock.Lock()
@@ -311,7 +442,7 @@ func (f *Interface) SendVia(via *HostInfo,
"headerLen", len(out), "headerLen", len(out),
"cipherOverhead", via.ConnectionState.eKey.Overhead(), "cipherOverhead", via.ConnectionState.eKey.Overhead(),
) )
return return nil, io.ErrShortBuffer
} }
// The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload. // The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload.
@@ -331,13 +462,36 @@ func (f *Interface) SendVia(via *HostInfo,
} }
if err != nil { if err != nil {
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
return nil, err
}
f.connectionManager.RelayUsed(relay.LocalIndex)
return out, nil
}
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
// to the payload for the ultimate target host, making this a useful method for sending
// handshake messages to peers through relay tunnels.
// via is the HostInfo through which the message is relayed.
// ad is the plaintext data to authenticate, but not encrypt
// nb is a buffer used to store the nonce value, re-used for performance reasons.
// out is a buffer used to store the result of the Encrypt operation
// q indicates which writer to use to send the packet.
func (f *Interface) SendVia(via *HostInfo,
relay *Relay,
ad,
nb,
out []byte,
nocopy bool,
) {
toSend, err := f.prepareSendVia(via, relay, ad, nb, out, nocopy)
if err != nil {
via.logger(f.l).Info("Failed to prepareSendVia", "error", err)
return return
} }
err = f.writers[0].WriteTo(out, via.remote) err = f.writers[0].WriteTo(toSend, via.remote)
if err != nil { if err != nil {
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
} }
f.connectionManager.RelayUsed(relay.LocalIndex)
} }
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {

View File

@@ -12,13 +12,14 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/wire" "github.com/slackhq/nebula/wire"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/overlay/batch"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -90,7 +91,11 @@ type Interface struct {
ctx context.Context ctx context.Context
writers []udp.Conn writers []udp.Conn
readers []tio.Queue readers []tio.Queue
wg sync.WaitGroup // batchers is one per tun queue, wrapping readers[i].
// decryptToTun sends plaintext into the batch.RxBatcher;
// listenOut calls its Flush at the end of each UDP recvmmsg batch.
batchers []batch.RxBatcher
wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown. // fatalErr holds the first unexpected reader error that caused shutdown.
// nil means "no fatal error" (yet) // nil means "no fatal error" (yet)
@@ -189,6 +194,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]tio.Queue, c.routines), readers: make([]tio.Queue, c.routines),
batchers: make([]batch.RxBatcher, c.routines),
myVpnNetworks: cs.myVpnNetworks, myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable, myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs, myVpnAddrs: cs.myVpnAddrs,
@@ -254,6 +260,10 @@ func (f *Interface) activate() error {
} }
} }
f.readers = f.inside.Readers() f.readers = f.inside.Readers()
for i := range f.readers {
arena := batch.NewArena(batch.DefaultPassthroughArenaCap)
f.batchers[i] = batch.NewPassthrough(f.readers[i], arena)
}
f.wg.Add(1) // for us to wait on Close() to return f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
@@ -310,14 +320,22 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { listener := func(fromUdpAddr netip.AddrPort, payload []byte, meta udp.RxMeta) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) plaintext := f.batchers[i].Reserve(len(payload))
}) f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(), meta)
}
flusher := func() {
if err := f.batchers[i].Flush(); err != nil {
f.l.Error("Failed to flush tun coalescer", "error", err)
}
}
err := li.ListenOut(listener, flusher)
if err != nil && !f.closed.Load() { if err != nil && !f.closed.Load() {
f.l.Error("Error while reading inbound packet, closing", "error", err) f.l.Error("Error while reading inbound packet, closing", "error", err)
@@ -332,6 +350,9 @@ func (f *Interface) listenIn(reader tio.Queue, q int) {
// TODO get the amount of bonus info from the reader // TODO get the amount of bonus info from the reader
packets := make([]wire.TunPacket, 1) packets := make([]wire.TunPacket, 1)
out := make([]byte, mtu) out := make([]byte, mtu)
rejectBuf := make([]byte, mtu)
arenaSize := batch.SendBatchCap * (udp.MTU + 32)
sb := batch.NewSendBatch(f.writers[q], batch.SendBatchCap, batch.NewArena(arenaSize))
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
@@ -346,9 +367,13 @@ func (f *Interface) listenIn(reader tio.Queue, q int) {
} }
break break
} }
ctCache := conntrackCache.Get() ctCache := conntrackCache.Get()
for i := range n { for i := range n{
f.consumeInsidePacket(packets[i].Bytes, fwPacket, nb, out, q, ctCache) f.consumeInsidePacket(packets[i], fwPacket, nb, sb, rejectBuf, q, ctCache)
}
if err := sb.Flush(); err != nil {
f.l.Error("Failed to write outgoing batch", "error", err, "writer", q)
} }
} }

View File

@@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -22,7 +23,7 @@ const (
var ErrOutOfWindow = errors.New("out of window packet") var ErrOutOfWindow = errors.New("out of window packet")
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -110,8 +111,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Relay packets are special // Relay packets are special
if isMessageRelay { if isMessageRelay {
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache, meta)
return return
} }
@@ -135,7 +135,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
case header.Message: case header.Message:
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache, meta)
default: default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
return return
@@ -168,7 +168,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
} }
} }
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
// The entire body is sent as AD, not encrypted. // The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
@@ -211,7 +211,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
relay: relay, relay: relay,
IsRelayed: true, IsRelayed: true,
} }
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, meta)
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
@@ -229,7 +229,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
switch targetRelay.Type { switch targetRelay.Type {
case ForwardingType: case ForwardingType:
// Forward this packet through the relay tunnel // Forward this packet through the relay tunnel
// Find the target HostInfo // Find the target HostInfo //todo it would potentially be nice to batch these
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
case TerminalType: case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
@@ -512,7 +512,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
err := newPacket(out, true, fwPacket) err := newPacket(out, true, fwPacket)
if err != nil { if err != nil {
hostinfo.logger(f.l).Warn("Error while validating inbound packet", hostinfo.logger(f.l).Warn("Error while validating inbound packet",
@@ -536,7 +536,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p
return return
} }
_, err = f.readers[q].Write(out) err = f.batchers[q].Commit(out)
if err != nil { if err != nil {
f.l.Error("Failed to write to tun", "error", err) f.l.Error("Failed to write to tun", "error", err)
} }

28
overlay/batch/batch.go Normal file
View File

@@ -0,0 +1,28 @@
package batch
import "net/netip"
// RxBatcher queues inbound (toward-the-tun) packets so a whole receive
// batch can be written out in one burst rather than one write per packet.
// Implementations are single-goroutine: one batcher per queue, no locking.
type RxBatcher interface {
	// Reserve hands out a sz-byte scratch slice for the caller to fill.
	// The bytes are borrowed from the batcher and recycled on Flush.
	Reserve(sz int) []byte
	// Commit borrows pkt. The caller must keep pkt valid until the next Flush
	Commit(pkt []byte) error
	// Flush emits every queued packet in arrival order. Returns the
	// first error observed; keeps draining so one bad packet doesn't hold up
	// the rest. After Flush returns, borrowed payload slices may be recycled.
	Flush() error
}
// TxBatcher queues outbound encrypted packets, each with a destination,
// for a single batched UDP send on Flush.
// NOTE(review): unlike RxBatcher.Commit, this Commit cannot report an
// error — confirm that per-packet commit failure is impossible by design.
type TxBatcher interface {
	// Reserve hands out a sz-byte scratch slice for the caller to fill.
	// The bytes are borrowed from the batcher and recycled on Flush.
	Reserve(sz int) []byte
	// Commit borrows pkt and records its destination plus the 2-bit
	// IP-level ECN codepoint to set on the outer (carrier) header. The
	// caller must keep pkt valid until the next Flush. Pass 0 (Not-ECT)
	// to leave the outer ECN field unset.
	Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
	// Flush emits every queued packet via the underlying batch writer in
	// arrival order. Returns an errors.Join of one or more errors. After Flush returns,
	// borrowed payload slices may be recycled.
	Flush() error
}

View File

@@ -0,0 +1,42 @@
package batch
// Arena is an injectable byte-slab that hands out non-overlapping borrowed
// slices via Reserve and releases them in bulk via Reset. Coalescers take
// an *Arena at construction so the caller controls the slab lifetime and
// can share one slab across multiple coalescers (MultiCoalescer hands the
// same *Arena to every lane so the lanes don't carry their own backings).
//
// A Reserve'd slice stays valid until the next Reset. When a Reserve does
// not fit, a fresh, larger backing array is allocated; size the arena up
// front with NewArena so the hot path never takes that allocation.
type Arena struct {
	buf []byte
}

// NewArena returns an Arena whose backing array is pre-allocated to the
// given capacity. A capacity of 0 is fine for callers that never Reserve
// (e.g. a test that only feeds the coalescer pre-made []byte packets via
// Commit).
func NewArena(capacity int) *Arena {
	return &Arena{buf: make([]byte, 0, capacity)}
}

// Reserve hands out a non-overlapping sz-byte slice from the arena. When
// the current backing cannot hold the request a larger one is allocated;
// slices borrowed earlier keep referencing the old backing and so remain
// valid until Reset.
func (a *Arena) Reserve(sz int) []byte {
	if free := cap(a.buf) - len(a.buf); free < sz {
		// Double the slab, or jump straight to sz for oversized requests.
		a.buf = make([]byte, 0, max(cap(a.buf)*2, sz))
	}
	off := len(a.buf)
	end := off + sz
	a.buf = a.buf[:end]
	// The full slice expression caps the borrow so a caller's append
	// cannot bleed into the next reservation.
	return a.buf[off:end:end]
}

// Reset releases every slice handed out since the last Reset; callers
// must not touch previously-borrowed slices afterwards. The backing array
// is retained so subsequent Reserves don't re-allocate.
func (a *Arena) Reset() {
	a.buf = a.buf[:0]
}

View File

@@ -0,0 +1,52 @@
package batch
import (
"io"
"github.com/slackhq/nebula/udp"
)
// Passthrough is a RxBatcher that doesn't batch anything: it only
// accumulates committed packets and writes them out one at a time on
// Flush.
type Passthrough struct {
	out io.Writer
	slots [][]byte
	arena *Arena
	// NOTE(review): cursor is unused in this file — confirm whether any
	// other file in the package still pokes it before removing.
	cursor int
}

const passthroughBaseNumSlots = 128

// DefaultPassthroughArenaCap is the recommended arena capacity for a
// standalone Passthrough batcher: 128 slots × udp.MTU ≈ 1.1 MiB.
const DefaultPassthroughArenaCap = passthroughBaseNumSlots * udp.MTU

// NewPassthrough wraps w in a Passthrough whose Reserve borrows from arena.
func NewPassthrough(w io.Writer, arena *Arena) *Passthrough {
	p := &Passthrough{
		out:   w,
		arena: arena,
		slots: make([][]byte, 0, passthroughBaseNumSlots),
	}
	return p
}

// Reserve borrows sz bytes from the underlying arena; valid until Flush.
func (p *Passthrough) Reserve(sz int) []byte {
	return p.arena.Reserve(sz)
}

// Commit queues pkt for the next Flush. pkt is borrowed, not copied.
func (p *Passthrough) Commit(pkt []byte) error {
	p.slots = append(p.slots, pkt)
	return nil
}

// Flush writes every queued packet in arrival order, then recycles the
// slot list and the arena. The first write error is remembered and
// returned, but draining continues so one bad packet doesn't strand the
// rest.
func (p *Passthrough) Flush() error {
	var firstErr error
	for i := range p.slots {
		if _, err := p.out.Write(p.slots[i]); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	clear(p.slots) // drop packet refs so stale borrows don't pin the GC
	p.slots = p.slots[:0]
	p.arena.Reset()
	return firstErr
}

65
overlay/batch/tx_batch.go Normal file
View File

@@ -0,0 +1,65 @@
package batch
import (
"net/netip"
"github.com/slackhq/nebula/udp"
)
// SendBatchCap is the number of packet slots a SendBatch is sized for.
const SendBatchCap = 128

// DefaultSendBatchArenaCap is the recommended arena capacity for a
// standalone SendBatch: 128 slots × (udp.MTU + 32) ≈ 1.1 MiB. The +32 covers
// the nebula header + AEAD tag tacked onto each plaintext segment.
const DefaultSendBatchArenaCap = SendBatchCap * (udp.MTU + 32)

// batchWriter is the minimal subset of udp.Conn needed by SendBatch to flush.
type batchWriter interface {
	WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error
}

// SendBatch accumulates encrypted UDP packets and flushes them via WriteBatch.
// One SendBatch is owned by each listenIn goroutine; no locking is needed.
// Slot bytes are borrowed from the injected Arena and remain valid until
// Flush, which Resets the arena.
type SendBatch struct {
	out   batchWriter
	bufs  [][]byte
	dsts  []netip.AddrPort
	ecns  []byte
	arena *Arena
}

// NewSendBatch makes a SendBatch with batchCap slots backed by arena.
func NewSendBatch(out batchWriter, batchCap int, arena *Arena) *SendBatch {
	b := &SendBatch{out: out, arena: arena}
	b.bufs = make([][]byte, 0, batchCap)
	b.dsts = make([]netip.AddrPort, 0, batchCap)
	b.ecns = make([]byte, 0, batchCap)
	return b
}

// Reserve borrows sz bytes from the arena; valid until the next Flush.
func (b *SendBatch) Reserve(sz int) []byte {
	return b.arena.Reserve(sz)
}

// Commit queues pkt for dst with the given outer-header ECN codepoint.
func (b *SendBatch) Commit(pkt []byte, dst netip.AddrPort, outerECN byte) {
	b.bufs = append(b.bufs, pkt)
	b.dsts = append(b.dsts, dst)
	b.ecns = append(b.ecns, outerECN)
}

// Flush hands every queued packet to the batch writer in a single call,
// then recycles the slot lists and the arena. An empty batch is a no-op.
func (b *SendBatch) Flush() error {
	var err error
	if len(b.bufs) != 0 {
		err = b.out.WriteBatch(b.bufs, b.dsts, b.ecns)
	}
	clear(b.bufs) // release packet refs for the GC
	b.bufs = b.bufs[:0]
	b.dsts = b.dsts[:0]
	b.ecns = b.ecns[:0]
	b.arena.Reset()
	return err
}

View File

@@ -0,0 +1,124 @@
package batch
import (
"net/netip"
"testing"
)
// fakeBatchWriter records the most recent WriteBatch call so tests can
// inspect exactly what a SendBatch flushed.
type fakeBatchWriter struct {
	bufs  [][]byte
	addrs []netip.AddrPort
	ecns  []byte
}

// WriteBatch deep-copies the payload buffers and shallow-copies addrs/ecns.
func (w *fakeBatchWriter) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error {
	// Snapshot — SendBatch.Flush nils its slot pointers right after WriteBatch
	// returns, so tests must capture data before that happens.
	w.bufs = make([][]byte, len(bufs))
	for i, b := range bufs {
		cp := make([]byte, len(b))
		copy(cp, b)
		w.bufs[i] = cp
	}
	w.addrs = append(w.addrs[:0], addrs...)
	w.ecns = append(w.ecns[:0], ecns...)
	return nil
}
// TestSendBatchReserveCommitFlush walks the happy path end to end:
// reserve four slots, commit a short packet into each, flush once, and
// verify the writer saw every payload and destination. It then checks
// that an empty Flush is a no-op and that the batch is reusable after.
func TestSendBatchReserveCommitFlush(t *testing.T) {
	fw := &fakeBatchWriter{}
	b := NewSendBatch(fw, 4, NewArena(32))
	ap := netip.MustParseAddrPort("10.0.0.1:4242")
	for i := 0; i < 4; i++ {
		slot := b.Reserve(32)
		if cap(slot) != 32 {
			t.Fatalf("slot %d: cap=%d want 32", i, cap(slot))
		}
		// Build a 3-byte packet in place at the front of the slot.
		pkt := append(slot[:0], byte(i), byte(i+1), byte(i+2))
		b.Commit(pkt, ap, 0)
	}
	if err := b.Flush(); err != nil {
		t.Fatalf("Flush: %v", err)
	}
	if len(fw.bufs) != 4 {
		t.Fatalf("WriteBatch got %d bufs want 4", len(fw.bufs))
	}
	for i, buf := range fw.bufs {
		if len(buf) != 3 || buf[0] != byte(i) {
			t.Errorf("buf %d: %x", i, buf)
		}
		if fw.addrs[i] != ap {
			t.Errorf("addr %d: got %v want %v", i, fw.addrs[i], ap)
		}
	}
	// Flush again with nothing committed — should be a no-op.
	fw.bufs = nil
	if err := b.Flush(); err != nil {
		t.Fatalf("empty Flush: %v", err)
	}
	if fw.bufs != nil {
		t.Fatalf("empty Flush triggered WriteBatch")
	}
	// Reuse after Flush.
	slot := b.Reserve(32)
	if cap(slot) != 32 {
		t.Fatalf("after Flush Reserve wrong cap: %d", cap(slot))
	}
}
// TestSendBatchSlotsDoNotOverlap proves consecutive Reserve calls hand out
// disjoint memory: bytes written into an earlier committed packet must
// survive writes into slots reserved after it.
func TestSendBatchSlotsDoNotOverlap(t *testing.T) {
	fw := &fakeBatchWriter{}
	sb := NewSendBatch(fw, 3, NewArena(8))
	dst := netip.MustParseAddrPort("10.0.0.1:80")
	for i := 0; i < 3; i++ {
		slot := sb.Reserve(8)
		sb.Commit(append(slot[:0], byte(0xA0+i), byte(0xB0+i)), dst, 0)
	}
	if err := sb.Flush(); err != nil {
		t.Fatalf("Flush: %v", err)
	}
	for i, got := range fw.bufs {
		if want0, want1 := byte(0xA0+i), byte(0xB0+i); got[0] != want0 || got[1] != want1 {
			t.Errorf("slot %d corrupted: %x", i, got)
		}
	}
}
// TestSendBatchGrowPreservesCommitted verifies the arena-grow path: a
// Reserve that outgrows the backing array must not corrupt packets already
// committed (they keep referencing the old backing until Flush).
func TestSendBatchGrowPreservesCommitted(t *testing.T) {
	fw := &fakeBatchWriter{}
	// Tiny initial backing forces a grow on the second Reserve.
	b := NewSendBatch(fw, 1, NewArena(4))
	ap := netip.MustParseAddrPort("10.0.0.1:80")
	s1 := b.Reserve(4)
	pkt1 := append(s1[:0], 0x11, 0x22, 0x33, 0x44)
	b.Commit(pkt1, ap, 0)
	s2 := b.Reserve(8) // exceeds remaining cap, triggers grow
	pkt2 := append(s2[:0], 0xA, 0xB, 0xC, 0xD, 0xE)
	b.Commit(pkt2, ap, 0)
	// pkt1 must still be intact even though backing reallocated.
	if pkt1[0] != 0x11 || pkt1[3] != 0x44 {
		t.Fatalf("first packet corrupted by grow: %x", pkt1)
	}
	if err := b.Flush(); err != nil {
		t.Fatalf("Flush: %v", err)
	}
	if len(fw.bufs) != 2 {
		t.Fatalf("got %d bufs want 2", len(fw.bufs))
	}
	if fw.bufs[0][0] != 0x11 || fw.bufs[0][3] != 0x44 {
		t.Errorf("first packet on the wire: %x", fw.bufs[0])
	}
	if fw.bufs[1][0] != 0xA || fw.bufs[1][4] != 0xE {
		t.Errorf("second packet on the wire: %x", fw.bufs[1])
	}
}

View File

@@ -8,6 +8,10 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
type Device interface { type Device interface {
io.Closer io.Closer
Activate() error Activate() error

12
overlay/tio/segment.go Normal file
View File

@@ -0,0 +1,12 @@
package tio
import "fmt"
// SegmentSuperpacket invokes fn once per segment of pkt.
// This is a stub implementation that does not actually support segmentation:
// a plain packet is handed to fn unchanged, and a genuine GSO superpacket is
// rejected with an error.
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
	if !pkt.GSO.IsSuperpacket() {
		// Not a superpacket: Bytes is a single IP datagram.
		return fn(pkt.Bytes)
	}
	return fmt.Errorf("tio: GSO superpacket on platform without segmentation support")
}

View File

@@ -18,7 +18,12 @@ type QueueSet interface {
// Capabilities advertises which kernel offload features a Queue successfully negotiated. // Capabilities advertises which kernel offload features a Queue successfully negotiated.
// Callers consult this to decide which coalescers to wire onto the write path. // Callers consult this to decide which coalescers to wire onto the write path.
type Capabilities struct { type Capabilities struct {
//none yet! // TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed
// to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe.
TSO bool
// USO means the kernel additionally agreed to TUN_F_USO4|USO6, so
// WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2.
USO bool
} }
// Queue is a readable/writable Poll queue. One Queue is driven by a single // Queue is a readable/writable Poll queue. One Queue is driven by a single
@@ -40,3 +45,78 @@ type Queue interface {
// or the zero value when q does not advertise any. // or the zero value when q does not advertise any.
Capabilities() Capabilities Capabilities() Capabilities
} }
// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes.
// The zero value means "not a superpacket" — Bytes is one regular IP
// datagram and no segmentation is required.
type GSOInfo struct {
	// Size is the GSO segment size: max payload bytes per segment
	// (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means
	// not a superpacket.
	Size uint16
	// HdrLen is the total L3+L4 header length within Bytes (already
	// corrected via correctHdrLen, so safe to slice on).
	HdrLen uint16
	// CsumStart is the L4 header offset inside Bytes (== L3 header
	// length).
	CsumStart uint16
	// Proto picks the L4 protocol (TCP or UDP) so the segmenter knows
	// which checksum/header layout to apply.
	Proto GSOProto
}

// GSOProto selects the L4 protocol for a GSO superpacket. Determines which
// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset
// inside the transport header virtio NEEDS_CSUM expects.
type GSOProto uint8

const (
	// GSOProtoNone marks a plain packet that needs no segmentation.
	GSOProtoNone GSOProto = iota
	// GSOProtoTCP selects TCP (TSO) superpacket handling.
	GSOProtoTCP
	// GSOProtoUDP selects UDP (USO) superpacket handling.
	GSOProtoUDP
)

// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket
// assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO
// support do not implement this interface and coalescing is skipped.
//
// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have
// filled in total length and IP csum). transportHdr is the TCP or UDP
// header (mutable - the L4 checksum field must hold the pseudo-header
// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics).
// pays are non-overlapping payload fragments whose concatenation is the
// full superpacket payload; they are read-only from the writer's
// perspective and must remain valid until the call returns. Every segment
// in pays except possibly the last is exactly the same size. proto picks
// the L4 protocol so the writer knows which GSOType / CsumOffset to set.
//
// Callers should also consult the Queue's negotiated Capabilities (e.g.
// via SupportsGSO) for the per-protocol capability; an implementation of
// GSOWriter is necessary but not sufficient since USO may not have been
// negotiated even when TSO was.
type GSOWriter interface {
	WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error
}
// SupportsGSO reports whether w can emit GSO superpackets for want: w must
// implement GSOWriter and its negotiated Capabilities must include the
// matching per-protocol flag (TSO for GSOProtoTCP, USO for GSOProtoUDP).
// On success the GSOWriter is returned alongside true; in every failure
// case the writer is nil and the result is false, so callers may gate on
// either value. (Queue always exposes Capabilities, so there is no
// "permissive" path for writers that did not negotiate — the previous doc
// comment describing one was stale.)
func SupportsGSO(w Queue, want GSOProto) (GSOWriter, bool) {
	gw, ok := w.(GSOWriter)
	if !ok {
		return nil, false
	}
	negotiated := false
	caps := w.Capabilities()
	switch want {
	case GSOProtoTCP:
		negotiated = caps.TSO
	case GSOProtoUDP:
		negotiated = caps.USO
	default:
		// GSOProtoNone (or any unknown proto) never qualifies for a
		// vectored GSO write.
	}
	if !negotiated {
		// Consistent with the !ok branch above: never hand back a
		// writer the caller must not use.
		return nil, false
	}
	return gw, true
}

View File

@@ -8,16 +8,49 @@ import (
const MTU = 9001 const MTU = 9001
// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is
// required to accept. Callers SHOULD NOT pass more than this per call; Linux
// backends preallocate sendmmsg scratch sized to this value, so exceeding it
// only costs additional sendmmsg chunks within a single WriteBatch call.
const MaxWriteBatch = 128
// RxMeta carries per-packet metadata extracted from the RX path (ancillary
// data, kernel offload state, etc.) and passed to EncReader callbacks.
// Backends that do not produce a particular signal leave its zero value.
//
// OuterECN is the 2-bit IP-level ECN codepoint stamped on the carrier
// datagram (extracted from IP_TOS / IPV6_TCLASS cmsg on Linux). Zero
// means Not-ECT, which is also the value backends without ECN RX support
// supply on every packet.
type RxMeta struct {
OuterECN byte
}
type EncReader func( type EncReader func(
addr netip.AddrPort, addr netip.AddrPort,
payload []byte, payload []byte,
meta RxMeta,
) )
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) error // ListenOut invokes r for each received packet. On batch-capable
// backends (recvmmsg), flush is called after each batch is fully
// delivered — callers use it to flush per-batch accumulators such as
// TUN write coalescers. Single-packet backends call flush after each
// packet. flush must not be nil.
ListenOut(r EncReader, flush func()) error
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
// WriteBatch sends a contiguous batch of packets, each with its own
// destination. bufs and addrs must have the same length. outerECNs may
// be nil (treated as all-zero / Not-ECT); when non-nil it must have the
// same length as bufs, and outerECNs[i] is the 2-bit IP-level ECN
// codepoint to set on packet i's outer header. Linux uses sendmmsg(2)
// for a single syscall and attaches the value as IP_TOS / IPV6_TCLASS
// cmsg; other backends ignore it. Returns on the first error; callers
// may observe a partial send if some packets went out before the error.
WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
SupportsMultipleReaders() bool SupportsMultipleReaders() bool
Close() error Close() error
@@ -31,7 +64,7 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader) error { func (NoopConn) ListenOut(_ EncReader, _ func()) error {
return nil return nil
} }
func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) SupportsMultipleReaders() bool {
@@ -40,6 +73,9 @@ func (NoopConn) SupportsMultipleReaders() bool {
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }
// WriteBatch discards the whole batch; NoopConn never sends anything,
// matching its WriteTo behavior.
func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort, _ []byte) error {
	return nil
}
func (NoopConn) ReloadConfig(_ *config.C) { func (NoopConn) ReloadConfig(_ *config.C) {
return return
} }

View File

@@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
} }
} }
// WriteBatch is the portable fallback: packets are sent one at a time via
// WriteTo, stopping at the first failure. The outer-ECN slice is ignored
// on this backend.
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
	for i := range bufs {
		if err := u.WriteTo(bufs[i], addrs[i]); err != nil {
			return err
		}
	}
	return nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) { func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()
@@ -165,7 +174,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {} return func() {}
} }
func (u *StdConn) ListenOut(r EncReader) error { func (u *StdConn) ListenOut(r EncReader, flush func()) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
@@ -179,7 +188,8 @@ func (u *StdConn) ListenOut(r EncReader) error {
u.l.Error("unexpected udp socket receive error", "error", err) u.l.Error("unexpected udp socket receive error", "error", err)
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{})
flush()
} }
} }

View File

@@ -44,6 +44,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
return err return err
} }
// WriteBatch falls back to one WriteToUDPAddrPort call per packet; the
// generic backend has no batched-send syscall to lean on. Returns on the
// first write error. The outer-ECN slice is ignored here.
func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
	for i := range bufs {
		if _, err := u.UDPConn.WriteToUDPAddrPort(bufs[i], addrs[i]); err != nil {
			return err
		}
	}
	return nil
}
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()
@@ -73,7 +82,7 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *GenericConn) ListenOut(r EncReader) error { func (u *GenericConn) ListenOut(r EncReader, flush func()) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
var lastRecvErr time.Time var lastRecvErr time.Time
@@ -93,7 +102,8 @@ func (u *GenericConn) ListenOut(r EncReader) error {
continue continue
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{})
flush()
} }
} }

View File

@@ -24,6 +24,22 @@ type StdConn struct {
isV4 bool isV4 bool
l *slog.Logger l *slog.Logger
batch int batch int
// sendmmsg scratch. Each queue has its own StdConn, so no locking is
// needed. Sized to MaxWriteBatch at construction; WriteBatch chunks
// larger inputs.
writeMsgs []rawMessage
writeIovs []iovec
writeNames [][]byte
// sendmmsg(2) callback state. sendmmsgCB is bound once in NewListener
// to the sendmmsgRun method value so passing it to rawConn.Write does
// not allocate a fresh closure per send; sendmmsgN/Sent/Errno carry
// the inputs and outputs across the call without escaping locals.
sendmmsgCB func(fd uintptr) bool
sendmmsgN int
sendmmsgSent int
sendmmsgErrno syscall.Errno
} }
func setReusePort(network, address string, c syscall.RawConn) error { func setReusePort(network, address string, c syscall.RawConn) error {
@@ -70,9 +86,23 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int)
} }
out.isV4 = af == unix.AF_INET out.isV4 = af == unix.AF_INET
out.prepareWriteMessages(MaxWriteBatch)
out.sendmmsgCB = out.sendmmsgRun
return out, nil return out, nil
} }
// prepareWriteMessages allocates the sendmmsg(2) scratch used by
// WriteBatch: n mmsghdr entries, n iovecs, and n sockaddr name buffers.
// Only the Name pointers are wired up here; the per-call fields (iov
// base/len, Namelen) are filled in by WriteBatch before each send.
// Called once from NewListener, so no locking is needed.
func (u *StdConn) prepareWriteMessages(n int) {
	u.writeMsgs = make([]rawMessage, n)
	u.writeIovs = make([]iovec, n)
	u.writeNames = make([][]byte, n)
	for i := range u.writeMsgs {
		// SizeofSockaddrInet6 is large enough for either address family.
		u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
		u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0]
	}
}
func (u *StdConn) SupportsMultipleReaders() bool { func (u *StdConn) SupportsMultipleReaders() bool {
return true return true
} }
@@ -171,7 +201,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
return int(n), true, nil return int(n), true, nil
} }
func (u *StdConn) listenOutSingle(r EncReader) error { func (u *StdConn) listenOutSingle(r EncReader, flush func()) error {
var err error var err error
var n int var n int
var from netip.AddrPort var from netip.AddrPort
@@ -183,16 +213,33 @@ func (u *StdConn) listenOutSingle(r EncReader) error {
return err return err
} }
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
r(from, buffer[:n]) // listenOutSingle uses ReadFromUDPAddrPort which discards cmsgs,
// so the outer ECN field is not visible on this path. Zero RxMeta
// (Not-ECT) means RFC 6040 combine is a no-op.
r(from, buffer[:n], RxMeta{})
flush()
} }
} }
func (u *StdConn) listenOutBatch(r EncReader) error { // readSockaddr decodes the source address out of a recvmmsg name buffer
func (u *StdConn) readSockaddr(name []byte) netip.AddrPort {
var ip netip.Addr var ip netip.Addr
// It's ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 {
ip, _ = netip.AddrFromSlice(name[4:8])
} else {
ip, _ = netip.AddrFromSlice(name[8:24])
}
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(name[2:4]))
}
func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
var n int var n int
var operr error var operr error
msgs, buffers, names := u.PrepareRawMessages(u.batch) bufSize := MTU
cmsgSpace := 0
msgs, buffers, names, _ := u.PrepareRawMessages(u.batch, bufSize, cmsgSpace)
//reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read
//defining it outside the loop so it gets re-used //defining it outside the loop so it gets re-used
@@ -211,22 +258,18 @@ func (u *StdConn) listenOutBatch(r EncReader) error {
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic r(u.readSockaddr(names[i]), buffers[i][:msgs[i].Len], RxMeta{})
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
} }
flush()
} }
} }
func (u *StdConn) ListenOut(r EncReader) error { func (u *StdConn) ListenOut(r EncReader, flush func()) error {
if u.batch == 1 { if u.batch == 1 {
return u.listenOutSingle(r) return u.listenOutSingle(r, flush)
} else { } else {
return u.listenOutBatch(r) return u.listenOutBatch(r, flush)
} }
} }
@@ -235,6 +278,120 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return err return err
} }
// WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on
// StdConn, one mmsghdr entry per packet. Inputs larger than the scratch
// (MaxWriteBatch entries) simply take multiple sendmmsg chunks.
// NOTE(review): an earlier comment claimed consecutive same-destination
// packets are coalesced into a single mmsghdr entry — this implementation
// does not do that; every packet gets its own entry.
//
// If sendmmsg returns an error and zero entries went out, we fall back to
// per-packet WriteTo for that chunk so the caller still gets best-effort
// delivery. On a partial send we resume at the first un-acked entry on
// the next iteration. The outerECNs argument is ignored on this path.
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
	for i := 0; i < len(bufs); {
		// Clamp to the preallocated scratch; the remainder is handled on
		// a later iteration as its own chunk.
		chunk := min(len(bufs)-i, len(u.writeMsgs))
		for k := 0; k < chunk; k++ {
			u.writeIovs[k].Base = &bufs[i+k][0]
			setIovLen(&u.writeIovs[k], len(bufs[i+k]))
			nlen, err := writeSockaddr(u.writeNames[k], addrs[i+k], u.isV4)
			if err != nil {
				// Address-family mismatch (e.g. v6 destination on a v4
				// socket). Nothing from this chunk has been sent yet.
				return err
			}
			hdr := &u.writeMsgs[k].Hdr
			hdr.Iov = &u.writeIovs[k]
			setMsgIovlen(hdr, 1)
			hdr.Namelen = uint32(nlen)
		}
		sent, serr := u.sendmmsg(chunk)
		if serr != nil && sent <= 0 {
			// sendmmsg returns -1 / sent=0 when entry 0 itself failed; log
			// that entry's destination and fall back to per-packet WriteTo
			// for the whole chunk so the caller still gets best-effort
			// delivery without duplicating packets the kernel accepted.
			u.l.Warn("sendmmsg failed, falling back to per-packet WriteTo",
				"err", serr,
				"entries", chunk,
				"entry0_dst", addrs[i],
				"isV4", u.isV4,
			)
			for k := 0; k < chunk; k++ {
				if werr := u.WriteTo(bufs[i+k], addrs[i+k]); werr != nil {
					return werr
				}
			}
			i += chunk
			continue
		}
		// Partial success: resume at the first entry the kernel did not
		// accept. NOTE(review): if sendmmsg ever returned 0 with a nil
		// error this loop would not advance — presumed unreachable; confirm.
		i += sent
	}
	return nil
}
// sendmmsg issues sendmmsg(2) against the first n entries of u.writeMsgs.
// The bound u.sendmmsgCB is passed to rawConn.Write so no closure is
// allocated per call; inputs and outputs ride on the StdConn fields.
// Returns the number of entries the kernel accepted and, if the syscall
// itself failed, the errno wrapped in a *net.OpError. An error from
// rawConn.Write (poller/fd failure) is returned as-is.
func (u *StdConn) sendmmsg(n int) (int, error) {
	u.sendmmsgN = n
	u.sendmmsgSent = 0
	u.sendmmsgErrno = 0
	if err := u.rawConn.Write(u.sendmmsgCB); err != nil {
		return u.sendmmsgSent, err
	}
	if u.sendmmsgErrno != 0 {
		// On failure sendmmsgSent holds the raw syscall return (see
		// sendmmsgRun); callers treat any value <= 0 as "nothing sent".
		return u.sendmmsgSent, &net.OpError{Op: "sendmmsg", Err: u.sendmmsgErrno}
	}
	return u.sendmmsgSent, nil
}
// sendmmsgRun is the rawConn.Write callback. It is bound once into
// u.sendmmsgCB at construction so it stays alloc-free in the hot path;
// inputs (sendmmsgN) and outputs (sendmmsgSent, sendmmsgErrno) ride on
// the receiver rather than escaping locals.
func (u *StdConn) sendmmsgRun(fd uintptr) bool {
	r1, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, fd,
		uintptr(unsafe.Pointer(&u.writeMsgs[0])), uintptr(u.sendmmsgN),
		0, 0, 0,
	)
	if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK {
		// Returning false makes the runtime poller wait for writability
		// and invoke this callback again, so EAGAIN never escapes to
		// callers of sendmmsg.
		return false
	}
	// On any other errno the raw return r1 is recorded as-is (the
	// syscall's failure value converts to a negative int), which the
	// caller's "sent <= 0" check interprets as "nothing sent".
	u.sendmmsgSent = int(r1)
	u.sendmmsgErrno = errno
	return true
}
// writeSockaddr encodes addr into buf (which must be at least
// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is
// true and addr is not a v4 (or v4-in-v6) address, returns an error.
func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) {
ap := addr.Addr().Unmap()
if isV4 {
if !ap.Is4() {
return 0, ErrInvalidIPv6RemoteForSocket
}
// struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) }
// sa_family is host endian.
binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET)
binary.BigEndian.PutUint16(buf[2:4], addr.Port())
ip4 := ap.As4()
copy(buf[4:8], ip4[:])
clear(buf[8:16])
return unix.SizeofSockaddrInet4, nil
}
// struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) }
binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6)
binary.BigEndian.PutUint16(buf[2:4], addr.Port())
binary.NativeEndian.PutUint32(buf[4:8], 0)
ip6 := addr.Addr().As16()
copy(buf[8:24], ip6[:])
binary.NativeEndian.PutUint32(buf[24:28], 0)
return unix.SizeofSockaddrInet6, nil
}
func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0) b := c.GetInt("listen.read_buffer", 0)
if b > 0 { if b > 0 {

View File

@@ -30,13 +30,18 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
var cmsgs []byte
if cmsgSpace > 0 {
cmsgs = make([]byte, n*cmsgSpace)
}
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
@@ -48,7 +53,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i])) msgs[i].Hdr.Namelen = uint32(len(names[i]))
if cmsgSpace > 0 {
msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
msgs[i].Hdr.Controllen = uint32(cmsgSpace)
}
} }
return msgs, buffers, names return msgs, buffers, names, cmsgs
}
func setIovLen(v *iovec, n int) {
v.Len = uint32(n)
}
func setMsgIovlen(m *msghdr, n int) {
m.Iovlen = uint32(n)
}
func setMsgControllen(m *msghdr, n int) {
m.Controllen = uint32(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint32(n)
} }

View File

@@ -33,13 +33,18 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
var cmsgs []byte
if cmsgSpace > 0 {
cmsgs = make([]byte, n*cmsgSpace)
}
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
@@ -51,7 +56,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i])) msgs[i].Hdr.Namelen = uint32(len(names[i]))
if cmsgSpace > 0 {
msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
msgs[i].Hdr.Controllen = uint64(cmsgSpace)
}
} }
return msgs, buffers, names return msgs, buffers, names, cmsgs
}
func setIovLen(v *iovec, n int) {
v.Len = uint64(n)
}
func setMsgIovlen(m *msghdr, n int) {
m.Iovlen = uint64(n)
}
func setMsgControllen(m *msghdr, n int) {
m.Controllen = uint64(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint64(n)
} }

View File

@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
return nil return nil
} }
func (u *RIOConn) ListenOut(r EncReader) error { func (u *RIOConn) ListenOut(r EncReader, flush func()) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
var lastRecvErr time.Time var lastRecvErr time.Time
@@ -161,7 +161,8 @@ func (u *RIOConn) ListenOut(r EncReader) error {
continue continue
} }
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], RxMeta{})
flush()
} }
} }
@@ -316,6 +317,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
} }
// WriteBatch emulates a batched send on the RIO backend by issuing one
// WriteTo per packet, stopping at the first failure. The outer-ECN slice
// is ignored on Windows.
func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
	for i := range bufs {
		if err := u.WriteTo(bufs[i], addrs[i]); err != nil {
			return err
		}
	}
	return nil
}
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock) sa, err := windows.Getsockname(u.sock)
if err != nil { if err != nil {

View File

@@ -157,15 +157,24 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil return nil
} }
} }
// WriteBatch emulates batch sends for the in-process tester by looping
// over WriteTo, returning on the first error.
func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
	for i := range bufs {
		if err := u.WriteTo(bufs[i], addrs[i]); err != nil {
			return err
		}
	}
	return nil
}
func (u *TesterConn) ListenOut(r EncReader) error { func (u *TesterConn) ListenOut(r EncReader, flush func()) error {
for { for {
select { select {
case <-u.done: case <-u.done:
return os.ErrClosed return os.ErrClosed
case p := <-u.RxPackets: case p := <-u.RxPackets:
r(p.From, p.Data) r(p.From, p.Data, RxMeta{})
p.Release() p.Release()
flush()
} }
} }
} }