diff --git a/batch_pipeline.go b/batch_pipeline.go
new file mode 100644
index 0000000..789c9d2
--- /dev/null
+++ b/batch_pipeline.go
@@ -0,0 +1,164 @@
+package nebula
+
+import (
+	"net/netip"
+
+	"github.com/slackhq/nebula/overlay"
+	"github.com/slackhq/nebula/udp"
+)
+
+// batchPipelines tracks whether the inside device can operate on packet batches
+// and, if so, holds the shared packet pool sized for the virtio headroom and
+// payload limits advertised by the device. It also owns the fan-in/fan-out
+// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
+type batchPipelines struct {
+	enabled    bool
+	inside     overlay.BatchCapableDevice
+	headroom   int
+	payloadCap int
+	pool       *overlay.PacketPool
+	batchSize  int
+	routines   int
+	rxQueues   []chan *overlay.Packet
+	txQueues   []chan queuedDatagram
+	tunQueues  []chan *overlay.Packet
+}
+
+type queuedDatagram struct {
+	packet *overlay.Packet
+	addr   netip.AddrPort
+}
+
+func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
+	if device == nil || routines <= 0 {
+		return
+	}
+	bcap, ok := device.(overlay.BatchCapableDevice)
+	if !ok {
+		return
+	}
+	headroom := bcap.BatchHeadroom()
+	payload := bcap.BatchPayloadCap()
+	if maxSegments < 1 {
+		maxSegments = 1
+	}
+	requiredPayload := udp.MTU * maxSegments
+	if payload < requiredPayload {
+		payload = requiredPayload
+	}
+	batchSize := bcap.BatchSize()
+	if headroom <= 0 || payload <= 0 || batchSize <= 0 {
+		return
+	}
+	bp.enabled = true
+	bp.inside = bcap
+	bp.headroom = headroom
+	bp.payloadCap = payload
+	bp.batchSize = batchSize
+	bp.routines = routines
+	bp.pool = overlay.NewPacketPool(headroom, payload)
+	queueCap := batchSize * defaultBatchQueueDepthFactor
+	if queueDepth > 0 {
+		queueCap = queueDepth
+	}
+	if queueCap < batchSize {
+		queueCap = batchSize
+	}
+	bp.rxQueues = make([]chan *overlay.Packet, routines)
+	bp.txQueues = make([]chan queuedDatagram, routines)
+	bp.tunQueues = make([]chan *overlay.Packet, routines)
+	for i := 0; i < routines; i++ {
+		bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
+		bp.txQueues[i] = make(chan queuedDatagram, queueCap)
+		bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
+	}
+}
+
+func (bp *batchPipelines) Pool() *overlay.PacketPool {
+	if bp == nil || !bp.enabled {
+		return nil
+	}
+	return bp.pool
+}
+
+func (bp *batchPipelines) Enabled() bool {
+	return bp != nil && bp.enabled
+}
+
+func (bp *batchPipelines) batchSizeHint() int {
+	if bp == nil || bp.batchSize <= 0 {
+		return 1
+	}
+	return bp.batchSize
+}
+
+func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
+	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
+		return nil
+	}
+	return bp.rxQueues[i]
+}
+
+func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
+	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
+		return nil
+	}
+	return bp.txQueues[i]
+}
+
+func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
+	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
+		return nil
+	}
+	return bp.tunQueues[i]
+}
+
+func (bp *batchPipelines) txQueueLen(i int) int {
+	q := bp.txQueue(i)
+	if q == nil {
+		return 0
+	}
+	return len(q)
+}
+
+func (bp *batchPipelines) tunQueueLen(i int) int {
+	q := bp.tunQueue(i)
+	if q == nil {
+		return 0
+	}
+	return len(q)
+}
+
+// enqueueRx hands a packet to an encrypt worker without blocking. The caller
+// keeps ownership (and must release the packet) when this returns false.
+func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
+	q := bp.rxQueue(i)
+	if q == nil {
+		return false
+	}
+	select {
+	case q <- pkt:
+		return true
+	default:
+		return false
+	}
+}
+
+// enqueueTx queues an encrypted datagram for the UDP writer without blocking,
+// so the caller can fall back to an immediate send when the queue is full.
+func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
+	q := bp.txQueue(i)
+	if q == nil {
+		return false
+	}
+	select {
+	case q <- queuedDatagram{packet: pkt, addr: addr}:
+		return true
+	default:
+		return false
+	}
+}
+
+// enqueueTun queues a decrypted packet for the TUN writer without blocking,
+// so the caller can fall back to a direct write when the queue is full.
+func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
+	q := bp.tunQueue(i)
+	if q == nil {
+		return false
+	}
+	select {
+	case q <- pkt:
+		return true
+	default:
+		return false
+	}
+}
+
+func (bp *batchPipelines) newPacket() *overlay.Packet {
+	if bp == nil || !bp.enabled || bp.pool == nil {
+		return nil
+	}
+	return bp.pool.Get()
+}
diff --git a/inside.go b/inside.go
index d24ed31..22b63d4 100644
--- a/inside.go
+++ b/inside.go
@@ -8,6 +8,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
+	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/routing"
 )
 
@@ -335,9 +336,21 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	if ci.eKey == nil {
 		return
 	}
-	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
+	target := remote
+	if !target.IsValid() {
+		target = hostinfo.remote
+	}
+	useRelay := !target.IsValid()
 	fullOut := out
 
+	var pkt *overlay.Packet
+	if !useRelay && f.batches.Enabled() {
+		pkt = f.batches.newPacket()
+		if pkt != nil {
+			out = pkt.Payload()[:0]
+		}
+	}
+
 	if useRelay {
 		if len(out) < header.Len {
 			// out always has a capacity of mtu, but not always a length greater than the header.Len.
@@ -376,36 +389,61 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		ci.writeLock.Unlock()
 	}
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).
- WithField("udpAddr", remote).Error("Failed to write outgoing packet") - } - } else { - // Try to send via a relay - for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) - if err != nil { - hostinfo.relayState.DeleteRelay(relayIP) - hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") - continue + if target.IsValid() { + if pkt != nil { + pkt.Len = len(out) + if f.l.Level >= logrus.DebugLevel { + f.l.WithFields(logrus.Fields{ + "queue": q, + "dest": target, + "payload_len": pkt.Len, + "use_batches": true, + "remote_index": hostinfo.remoteIndexId, + }).Debug("enqueueing packet to UDP batch queue") } - f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) - break + if f.tryQueuePacket(q, pkt, target) { + return + } + if f.l.Level >= logrus.DebugLevel { + f.l.WithFields(logrus.Fields{ + "queue": q, + "dest": target, + }).Debug("failed to enqueue packet; falling back to immediate send") + } + f.writeImmediatePacket(q, pkt, target, hostinfo) + return } + if f.tryQueueDatagram(q, out, target) { + return + } + f.writeImmediate(q, out, target, hostinfo) + return + } + + // fall back to relay path + if pkt != nil { + pkt.Release() + } + + // Try to send via a relay + for _, relayIP := range hostinfo.relayState.CopyRelayIps() { + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) + if err != nil { + hostinfo.relayState.DeleteRelay(relayIP) + hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + continue + } + f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) + break } } diff --git a/interface.go b/interface.go index 082906d..05ad918 100644 --- a/interface.go +++ b/interface.go @@ -21,7 +21,13 @@ import ( "github.com/slackhq/nebula/udp" ) -const mtu = 9001 +const ( + mtu = 9001 + defaultGSOFlushInterval = 150 * time.Microsecond + defaultBatchQueueDepthFactor = 4 + defaultGSOMaxSegments = 8 + maxKernelGSOSegments = 64 +) type InterfaceConfig struct { HostMap *HostMap @@ -36,6 +42,9 @@ type InterfaceConfig struct { connectionManager *connectionManager DropLocalBroadcast bool DropMulticast bool + EnableGSO bool + EnableGRO bool + GSOMaxSegments int routines int MessageMetrics *MessageMetrics version string @@ -47,6 +56,8 @@ type InterfaceConfig struct { reQueryWait time.Duration ConntrackCacheTimeout time.Duration + BatchFlushInterval time.Duration + BatchQueueDepth int l *logrus.Logger } @@ -84,9 +95,20 @@ type Interface struct { version string conntrackCacheTimeout time.Duration + batchQueueDepth int + enableGSO bool + enableGRO bool + gsoMaxSegments int + batchUDPQueueGauge metrics.Gauge + batchUDPFlushCounter metrics.Counter + batchTunQueueGauge metrics.Gauge + batchTunFlushCounter metrics.Counter + batchFlushInterval atomic.Int64 + sendSem chan struct{} writers []udp.Conn readers []io.ReadWriteCloser + batches batchPipelines metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -161,6 +183,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no connection manager") } + if c.GSOMaxSegments <= 0 { + c.GSOMaxSegments = defaultGSOMaxSegments + } + if c.GSOMaxSegments > maxKernelGSOSegments { + c.GSOMaxSegments = maxKernelGSOSegments + } + if c.BatchQueueDepth <= 0 { + c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor + 
+	}
+	if c.BatchFlushInterval < 0 {
+		c.BatchFlushInterval = 0
+	}
+	if c.BatchFlushInterval == 0 && c.EnableGSO {
+		c.BatchFlushInterval = defaultGSOFlushInterval
+	}
+
 	cs := c.pki.getCertState()
 	ifce := &Interface{
 		pki: c.pki,
@@ -186,6 +224,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		relayManager:          c.relayManager,
 		connectionManager:     c.connectionManager,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
+		batchQueueDepth:       c.BatchQueueDepth,
+		enableGSO:             c.EnableGSO,
+		enableGRO:             c.EnableGRO,
+		gsoMaxSegments:        c.GSOMaxSegments,
 
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
@@ -198,8 +240,25 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	}
 
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
+	ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
+	ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
+	ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
+	ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
+	ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
+	ifce.sendSem = make(chan struct{}, c.routines)
+	ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
+	if c.l.Level >= logrus.DebugLevel {
+		c.l.WithFields(logrus.Fields{
+			"enableGSO":       c.EnableGSO,
+			"enableGRO":       c.EnableGRO,
+			"gsoMaxSegments":  c.GSOMaxSegments,
+			"batchQueueDepth": c.BatchQueueDepth,
+			"batchFlush":      c.BatchFlushInterval,
+			"batching":        ifce.batches.Enabled(),
+		}).Debug("initialized batch pipelines")
+	}
 
 	ifce.connectionManager.intf = ifce
 
@@ -248,6 +307,18 @@ func (f *Interface) run() {
 		go f.listenOut(i)
 	}
 
+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
+	}
+
+	if f.batches.Enabled() {
+		for i := 0; i < f.routines; i++ {
+			go f.runInsideBatchWorker(i)
+			go f.runTunWriteQueue(i)
+			go f.runSendQueue(i)
+		}
+	}
+
 	// Launch n queues to read packets from tun dev
 	for i := 0; i < f.routines; i++ {
 		go f.listenIn(f.readers[i], i)
@@ -279,6 +350,17 @@ func (f *Interface) listenOut(i int) {
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	runtime.LockOSThread()
 
+	if f.batches.Enabled() {
+		if br, ok := reader.(overlay.BatchReader); ok {
+			f.listenInBatchLocked(reader, br, i)
+			return
+		}
+	}
+
+	f.listenInLegacyLocked(reader, i)
+}
+
+func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
 	packet := make([]byte, mtu)
 	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}
@@ -302,6 +384,489 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	}
 }
 
+func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
+	pool := f.batches.Pool()
+	if pool == nil {
+		f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
+		f.listenInLegacyLocked(raw, i)
+		return
+	}
+
+	for {
+		packets, err := reader.ReadIntoBatch(pool)
+		if err != nil {
+			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
+				return
+			}
+
+			f.l.WithError(err).Error("Error while reading outbound packet batch")
+			os.Exit(2)
+		}
+
+		if len(packets) == 0 {
+			continue
+		}
+
+		for _, pkt := range packets {
+			if pkt == nil {
+				continue
+			}
+			if !f.batches.enqueueRx(i, pkt) {
+				pkt.Release()
+			}
+		}
+	}
+}
+
+func (f *Interface) runInsideBatchWorker(i int) {
+	queue := f.batches.rxQueue(i)
+	if queue == nil {
+		return
+	}
+
+	out := make([]byte, mtu)
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+
+	for pkt := range queue {
+		if pkt == nil {
+			continue
+		}
+		f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		pkt.Release()
+	}
+}
+
+func (f *Interface) runSendQueue(i int) {
+	queue := f.batches.txQueue(i)
+	if queue == nil {
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
+		}
+		return
+	}
+	writer := f.writerForIndex(i)
+	if writer == nil {
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
+		}
+		return
+	}
+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithField("queue", i).Debug("send queue worker started")
+	}
+	defer func() {
+		if f.l.Level >= logrus.WarnLevel {
+			f.l.WithField("queue", i).Warn("send queue worker exited")
+		}
+	}()
+
+	batchCap := f.batches.batchSizeHint()
+	if batchCap <= 0 {
+		batchCap = 1
+	}
+	gsoLimit := f.effectiveGSOMaxSegments()
+	if gsoLimit > batchCap {
+		batchCap = gsoLimit
+	}
+	pending := make([]queuedDatagram, 0, batchCap)
+	var (
+		flushTimer *time.Timer
+		flushC     <-chan time.Time
+	)
+	dispatch := func(reason string, timerFired bool) {
+		if len(pending) == 0 {
+			return
+		}
+		batch := pending
+		f.flushAndReleaseBatch(i, writer, batch, reason)
+		for idx := range batch {
+			batch[idx] = queuedDatagram{}
+		}
+		pending = pending[:0]
+		if flushTimer != nil {
+			if !timerFired {
+				if !flushTimer.Stop() {
+					select {
+					case <-flushTimer.C:
+					default:
+					}
+				}
+			}
+			flushTimer = nil
+			flushC = nil
+		}
+	}
+	armTimer := func() {
+		delay := f.currentBatchFlushInterval()
+		if delay <= 0 {
+			dispatch("nogso", false)
+			return
+		}
+		if flushTimer == nil {
+			flushTimer = time.NewTimer(delay)
+			flushC = flushTimer.C
+		}
+	}
+
+	for {
+		select {
+		case d := <-queue:
+			if d.packet == nil {
+				continue
+			}
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithFields(logrus.Fields{
+					"queue":       i,
+					"payload_len": d.packet.Len,
+					"dest":        d.addr,
+				}).Debug("send queue received packet")
+			}
+			pending = append(pending, d)
+			if gsoLimit > 0 && len(pending) >= gsoLimit {
+				dispatch("gso", false)
+				continue
+			}
+			if len(pending) >= cap(pending) {
+				dispatch("cap", false)
+				continue
+			}
+			armTimer()
+			f.observeUDPQueueLen(i)
+		case <-flushC:
+			dispatch("timer", true)
+		}
+	}
+}
+
+func (f *Interface) runTunWriteQueue(i int) {
+	queue := f.batches.tunQueue(i)
+	if queue == nil {
+		return
+	}
+	writer := f.batches.inside
+	if writer == nil {
+		return
+	}
+
+	batchCap := f.batches.batchSizeHint()
+	if batchCap <= 0 {
+		batchCap = 1
+	}
+	pending := make([]*overlay.Packet, 0, batchCap)
+	var (
+		flushTimer *time.Timer
+		flushC     <-chan time.Time
+	)
+	flush := func(reason string, timerFired bool) {
+		if len(pending) == 0 {
+			return
+		}
+		if _, err := writer.WriteBatch(pending); err != nil {
+			f.l.WithError(err).
+				WithField("queue", i).
+				WithField("reason", reason).
+ Warn("Failed to write tun batch") + } + for idx := range pending { + if pending[idx] != nil { + pending[idx].Release() + } + } + pending = pending[:0] + if flushTimer != nil { + if !timerFired { + if !flushTimer.Stop() { + select { + case <-flushTimer.C: + default: + } + } + } + flushTimer = nil + flushC = nil + } + } + armTimer := func() { + delay := f.currentBatchFlushInterval() + if delay <= 0 { + return + } + if flushTimer == nil { + flushTimer = time.NewTimer(delay) + flushC = flushTimer.C + } + } + + for { + select { + case pkt := <-queue: + if pkt == nil { + continue + } + pending = append(pending, pkt) + if len(pending) >= cap(pending) { + flush("cap", false) + continue + } + armTimer() + f.observeTunQueueLen(i) + case <-flushC: + flush("timer", true) + } + } +} + +func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) { + if len(batch) == 0 { + return + } + f.flushDatagrams(index, writer, batch, reason) + for idx := range batch { + if batch[idx].packet != nil { + batch[idx].packet.Release() + batch[idx].packet = nil + } + } + if f.batchUDPFlushCounter != nil { + f.batchUDPFlushCounter.Inc(int64(len(batch))) + } +} + +func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) { + if len(batch) == 0 { + return + } + if f.l.Level >= logrus.DebugLevel { + f.l.WithFields(logrus.Fields{ + "writer": index, + "reason": reason, + "pending": len(batch), + }).Debug("udp batch flush summary") + } + maxSeg := f.effectiveGSOMaxSegments() + if bw, ok := writer.(udp.BatchConn); ok { + chunkCap := maxSeg + if chunkCap <= 0 { + chunkCap = len(batch) + } + chunk := make([]udp.Datagram, 0, chunkCap) + var ( + currentAddr netip.AddrPort + segments int + ) + flushChunk := func() { + if len(chunk) == 0 { + return + } + if f.l.Level >= logrus.DebugLevel { + f.l.WithFields(logrus.Fields{ + "writer": index, + "segments": len(chunk), + "dest": chunk[0].Addr, + "reason": reason, + "pending_total": len(batch), + }).Debug("flushing UDP batch") + } + if err := bw.WriteBatch(chunk); err != nil { + f.l.WithError(err). + WithField("writer", index). + WithField("reason", reason). + Warn("Failed to write UDP batch") + } + chunk = chunk[:0] + segments = 0 + } + for _, item := range batch { + if item.packet == nil || !item.addr.IsValid() { + continue + } + payload := item.packet.Payload()[:item.packet.Len] + if segments == 0 { + currentAddr = item.addr + } + if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) { + flushChunk() + currentAddr = item.addr + } + chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr}) + segments++ + } + flushChunk() + return + } + for _, item := range batch { + if item.packet == nil || !item.addr.IsValid() { + continue + } + if f.l.Level >= logrus.DebugLevel { + f.l.WithFields(logrus.Fields{ + "writer": index, + "reason": reason, + "dest": item.addr, + "segments": 1, + }).Debug("flushing UDP batch") + } + if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil { + f.l.WithError(err). + WithField("writer", index). + WithField("udpAddr", item.addr). + WithField("reason", reason). 
+ Warn("Failed to write UDP packet") + } + } +} + +func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool { + if !addr.IsValid() || !f.batches.Enabled() { + return false + } + pkt := f.batches.newPacket() + if pkt == nil { + return false + } + payload := pkt.Payload() + if len(payload) < len(buf) { + pkt.Release() + return false + } + copy(payload, buf) + pkt.Len = len(buf) + if f.batches.enqueueTx(q, pkt, addr) { + f.observeUDPQueueLen(q) + return true + } + pkt.Release() + return false +} + +func (f *Interface) writerForIndex(i int) udp.Conn { + if i < 0 || i >= len(f.writers) { + return nil + } + return f.writers[i] +} + +func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) { + writer := f.writerForIndex(q) + if writer == nil { + f.l.WithField("udpAddr", addr). + WithField("writer", q). + Error("Failed to write outgoing packet: no writer available") + return + } + if err := writer.WriteTo(buf, addr); err != nil { + hostinfo.logger(f.l). + WithError(err). + WithField("udpAddr", addr). + Error("Failed to write outgoing packet") + } +} + +func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool { + if pkt == nil || !addr.IsValid() || !f.batches.Enabled() { + return false + } + if f.batches.enqueueTx(q, pkt, addr) { + f.observeUDPQueueLen(q) + return true + } + return false +} + +func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) { + if pkt == nil { + return + } + writer := f.writerForIndex(q) + if writer == nil { + f.l.WithField("udpAddr", addr). + WithField("writer", q). + Error("Failed to write outgoing packet: no writer available") + pkt.Release() + return + } + if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil { + hostinfo.logger(f.l). + WithError(err). + WithField("udpAddr", addr). 
+ Error("Failed to write outgoing packet") + } + pkt.Release() +} + +func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) { + if pkt == nil { + return + } + writer := f.readers[q] + if writer == nil { + pkt.Release() + return + } + if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil { + f.l.WithError(err).Error("Failed to write to tun") + } + pkt.Release() +} + +func (f *Interface) observeUDPQueueLen(i int) { + if f.batchUDPQueueGauge == nil { + return + } + f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i))) +} + +func (f *Interface) observeTunQueueLen(i int) { + if f.batchTunQueueGauge == nil { + return + } + f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i))) +} + +func (f *Interface) currentBatchFlushInterval() time.Duration { + if v := f.batchFlushInterval.Load(); v > 0 { + return time.Duration(v) + } + return 0 +} + +func (f *Interface) effectiveGSOMaxSegments() int { + max := f.gsoMaxSegments + if max <= 0 { + max = defaultGSOMaxSegments + } + if max > maxKernelGSOSegments { + max = maxKernelGSOSegments + } + if !f.enableGSO { + return 1 + } + return max +} + +type udpOffloadConfigurator interface { + ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) +} + +func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) { + if maxSegments <= 0 { + maxSegments = defaultGSOMaxSegments + } + if maxSegments > maxKernelGSOSegments { + maxSegments = maxKernelGSOSegments + } + f.enableGSO = enableGSO + f.enableGRO = enableGRO + f.gsoMaxSegments = maxSegments + for _, writer := range f.writers { + if cfg, ok := writer.(udpOffloadConfigurator); ok { + cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments) + } + } +} + func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) @@ -404,6 +969,42 @@ func (f *Interface) reloadMisc(c *config.C) { f.reQueryWait.Store(int64(n)) f.l.Info("timers.requery_wait_duration has changed") } + + if c.HasChanged("listen.gso_flush_timeout") { + d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval) + if d < 0 { + d = 0 + } + f.batchFlushInterval.Store(int64(d)) + f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed") + } else if c.HasChanged("batch.flush_interval") { + d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval) + if d < 0 { + d = 0 + } + f.batchFlushInterval.Store(int64(d)) + f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout") + } + + if c.HasChanged("batch.queue_depth") { + n := c.GetInt("batch.queue_depth", f.batchQueueDepth) + if n != f.batchQueueDepth { + f.batchQueueDepth = n + f.l.Warn("batch.queue_depth changes require a restart to take effect") + } + } + + if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") { + enableGSO := c.GetBool("listen.enable_gso", f.enableGSO) + enableGRO := c.GetBool("listen.enable_gro", f.enableGRO) + maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments) + f.applyOffloadConfig(enableGSO, enableGRO, maxSeg) + f.l.WithFields(logrus.Fields{ + "enableGSO": enableGSO, + "enableGRO": enableGRO, + "gsoMaxSegments": maxSeg, + }).Info("listen GSO/GRO configuration updated") + } } func (f *Interface) emitStats(ctx context.Context, i time.Duration) { diff --git a/main.go b/main.go index d9666ce..9ef8ab8 100644 --- a/main.go +++ b/main.go @@ -144,6 +144,20 @@ func Main(c *config.C, configTest 
@@ -144,6 +144,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	// set up our UDP listener
 	udpConns := make([]udp.Conn, routines)
 	port := c.GetInt("listen.port", 0)
+	enableGSO := c.GetBool("listen.enable_gso", true)
+	enableGRO := c.GetBool("listen.enable_gro", true)
+	gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
+	if gsoMaxSegments <= 0 {
+		gsoMaxSegments = defaultGSOMaxSegments
+	}
+	if gsoMaxSegments > maxKernelGSOSegments {
+		gsoMaxSegments = maxKernelGSOSegments
+	}
+	gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
+	if gsoFlushTimeout < 0 {
+		gsoFlushTimeout = 0
+	}
+	batchQueueDepth := c.GetInt("batch.queue_depth", 0)
 
 	if !configTest {
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
@@ -179,6 +193,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
 			udpServer.ReloadConfig(c)
+			if cfg, ok := udpServer.(interface {
+				ConfigureOffload(bool, bool, int)
+			}); ok {
+				cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
+			}
 			udpConns[i] = udpServer
 
 			// If port is dynamic, discover it before the next pass through the for loop
@@ -246,12 +265,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		reQueryWait:           c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
 		DropLocalBroadcast:    c.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:         c.GetBool("tun.drop_multicast", false),
+		EnableGSO:             enableGSO,
+		EnableGRO:             enableGRO,
+		GSOMaxSegments:        gsoMaxSegments,
 		routines:              routines,
 		MessageMetrics:        messageMetrics,
 		version:               buildVersion,
 		relayManager:          NewRelayManager(ctx, l, hostMap, c),
 		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
+		BatchFlushInterval:    gsoFlushTimeout,
+		BatchQueueDepth:       batchQueueDepth,
 		l:                     l,
 	}
 
@@ -263,6 +287,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}
 
 	ifce.writers = udpConns
+	ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
 	lightHouse.ifce = ifce
 
 	ifce.RegisterConfigChangeCallbacks(c)
diff --git a/outside.go b/outside.go
index 5ff87bd..6135fd0 100644
--- a/outside.go
+++ b/outside.go
@@ -12,6 +12,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/overlay"
 	"golang.org/x/net/ipv4"
 )
 
@@ -466,22 +467,41 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 }
 
 func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
-	var err error
+	var (
+		err error
+		pkt *overlay.Packet
+	)
+
+	if f.batches.tunQueue(q) != nil {
+		pkt = f.batches.newPacket()
+		if pkt != nil {
+			out = pkt.Payload()[:0]
+		}
+	}
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		return false
 	}
 
 	err = newPacket(out, true, fwPacket)
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")
 		return false
 	}
 
 	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet") return false @@ -489,6 +509,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { + if pkt != nil { + pkt.Release() + } // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) @@ -501,8 +524,17 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out } f.connectionManager.In(hostinfo) - _, err = f.readers[q].Write(out) - if err != nil { + if pkt != nil { + pkt.Len = len(out) + if f.batches.enqueueTun(q, pkt) { + f.observeTunQueueLen(q) + return true + } + f.writePacketToTun(q, pkt) + return true + } + + if _, err = f.readers[q].Write(out); err != nil { f.l.WithError(err).Error("Failed to write to tun") } return true diff --git a/overlay/device.go b/overlay/device.go index 07146ab..b44b095 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -3,6 +3,7 @@ package overlay import ( "io" "net/netip" + "sync" "github.com/slackhq/nebula/routing" ) @@ -15,3 +16,84 @@ type Device interface { RoutesFor(netip.Addr) routing.Gateways NewMultiQueueReader() (io.ReadWriteCloser, error) } + +// Packet represents a single packet buffer with optional headroom to carry +// metadata (for example virtio-net headers). +type Packet struct { + Buf []byte + Offset int + Len int + release func() +} + +func (p *Packet) Payload() []byte { + return p.Buf[p.Offset : p.Offset+p.Len] +} + +func (p *Packet) Reset() { + p.Len = 0 + p.Offset = 0 + p.release = nil +} + +func (p *Packet) Release() { + if p.release != nil { + p.release() + p.release = nil + } +} + +func (p *Packet) Capacity() int { + return len(p.Buf) - p.Offset +} + +// PacketPool manages reusable buffers with headroom. +type PacketPool struct { + headroom int + blksz int + pool sync.Pool +} + +func NewPacketPool(headroom, payload int) *PacketPool { + p := &PacketPool{headroom: headroom, blksz: headroom + payload} + p.pool.New = func() any { + buf := make([]byte, p.blksz) + return &Packet{Buf: buf, Offset: headroom} + } + return p +} + +func (p *PacketPool) Get() *Packet { + pkt := p.pool.Get().(*Packet) + pkt.Offset = p.headroom + pkt.Len = 0 + pkt.release = func() { p.put(pkt) } + return pkt +} + +func (p *PacketPool) put(pkt *Packet) { + pkt.Reset() + p.pool.Put(pkt) +} + +// BatchReader allows reading multiple packets into a shared pool with +// preallocated headroom (e.g. virtio-net headers). +type BatchReader interface { + ReadIntoBatch(pool *PacketPool) ([]*Packet, error) +} + +// BatchWriter writes a slice of packets that carry their own metadata. +type BatchWriter interface { + WriteBatch(packets []*Packet) (int, error) +} + +// BatchCapableDevice describes a device that can efficiently read and write +// batches of packets with virtio headroom. 
+type BatchCapableDevice interface {
+	Device
+	BatchReader
+	BatchWriter
+	BatchHeadroom() int
+	BatchPayloadCap() int
+	BatchSize() int
+}
diff --git a/overlay/tun_linux_batch.go b/overlay/tun_linux_batch.go
new file mode 100644
index 0000000..290d165
--- /dev/null
+++ b/overlay/tun_linux_batch.go
@@ -0,0 +1,56 @@
+//go:build linux && !android && !e2e_testing
+
+package overlay
+
+import "fmt"
+
+func (t *tun) batchIO() (*wireguardTunIO, bool) {
+	io, ok := t.ReadWriteCloser.(*wireguardTunIO)
+	return io, ok
+}
+
+func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
+	io, ok := t.batchIO()
+	if !ok {
+		return nil, fmt.Errorf("wireguard batch I/O not enabled")
+	}
+	return io.ReadIntoBatch(pool)
+}
+
+func (t *tun) WriteBatch(packets []*Packet) (int, error) {
+	io, ok := t.batchIO()
+	if ok {
+		return io.WriteBatch(packets)
+	}
+	for i, pkt := range packets {
+		if pkt == nil {
+			continue
+		}
+		if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
+			// Keep the BatchWriter contract: consume the remaining packets
+			// even on error so callers never release them twice.
+			releasePackets(packets[i:])
+			return 0, err
+		}
+		pkt.Release()
+	}
+	return len(packets), nil
+}
+
+func (t *tun) BatchHeadroom() int {
+	if io, ok := t.batchIO(); ok {
+		return io.BatchHeadroom()
+	}
+	return 0
+}
+
+func (t *tun) BatchPayloadCap() int {
+	if io, ok := t.batchIO(); ok {
+		return io.BatchPayloadCap()
+	}
+	return 0
+}
+
+func (t *tun) BatchSize() int {
+	if io, ok := t.batchIO(); ok {
+		return io.BatchSize()
+	}
+	return 1
+}
diff --git a/overlay/wireguard_tun_linux.go b/overlay/wireguard_tun_linux.go
index 1c0fc5d..c8a36fc 100644
--- a/overlay/wireguard_tun_linux.go
+++ b/overlay/wireguard_tun_linux.go
@@ -14,15 +14,15 @@ type wireguardTunIO struct {
 	mtu       int
 	batchSize int
 
-	readMu   sync.Mutex
-	readBufs [][]byte
-	readLens []int
-	pending  [][]byte
-	pendIdx  int
+	readMu      sync.Mutex
+	readBuffers [][]byte
+	readLens    []int
+	legacyBuf   []byte
 
-	writeMu   sync.Mutex
-	writeBuf  []byte
-	writeWrap [][]byte
+	writeMu      sync.Mutex
+	writeBuf     []byte
+	writeWrap    [][]byte
+	writeBuffers [][]byte
 }
 
 func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
@@ -33,17 +33,12 @@ func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
 	if mtu <= 0 {
 		mtu = DefaultMTU
 	}
-	bufs := make([][]byte, batch)
-	for i := range bufs {
-		bufs[i] = make([]byte, wgtun.VirtioNetHdrLen+mtu)
-	}
 	return &wireguardTunIO{
 		dev:       dev,
 		mtu:       mtu,
 		batchSize: batch,
-		readBufs:  bufs,
 		readLens:  make([]int, batch),
-		pending:   make([][]byte, 0, batch),
+		legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
 		writeBuf:  make([]byte, wgtun.VirtioNetHdrLen+mtu),
 		writeWrap: make([][]byte, 1),
 	}
@@ -53,29 +48,21 @@ func (w *wireguardTunIO) Read(p []byte) (int, error) {
 	w.readMu.Lock()
 	defer w.readMu.Unlock()
 
-	for {
-		if w.pendIdx < len(w.pending) {
-			segment := w.pending[w.pendIdx]
-			w.pendIdx++
-			n := copy(p, segment)
-			return n, nil
-		}
-
-		n, err := w.dev.Read(w.readBufs, w.readLens, wgtun.VirtioNetHdrLen)
-		if err != nil {
-			return 0, err
-		}
-		w.pending = w.pending[:0]
-		w.pendIdx = 0
-		for i := 0; i < n; i++ {
-			length := w.readLens[i]
-			if length == 0 {
-				continue
-			}
-			segment := w.readBufs[i][wgtun.VirtioNetHdrLen : wgtun.VirtioNetHdrLen+length]
-			w.pending = append(w.pending, segment)
-		}
-	}
+	// Wrap legacyBuf directly rather than reusing readBuffers, which
+	// ReadIntoBatch repoints at pooled packet buffers; sharing that slice
+	// here would read into a released buffer and copy out stale data.
+	bufs := [][]byte{w.legacyBuf}
+	n, err := w.dev.Read(bufs, w.readLens[:1], wgtun.VirtioNetHdrLen)
+	if err != nil {
+		return 0, err
+	}
+	if n == 0 {
+		return 0, nil
+	}
+	length := w.readLens[0]
+	copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
+	return length, nil
 }
 
 func (w *wireguardTunIO) Write(p []byte) (int, error) {
@@ -97,6 +84,134 @@ func (w *wireguardTunIO) Write(p []byte) (int, error) {
 	return len(p), nil
 }
 
+func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
+	if pool == nil {
+		return nil, fmt.Errorf("wireguard tun: packet pool is nil")
+	}
+
+	w.readMu.Lock()
+	defer w.readMu.Unlock()
+
+	if len(w.readBuffers) < w.batchSize {
+		w.readBuffers = make([][]byte, w.batchSize)
+	}
+	if len(w.readLens) < w.batchSize {
+		w.readLens = make([]int, w.batchSize)
+	}
+
+	packets := make([]*Packet, w.batchSize)
+	requiredHeadroom := w.BatchHeadroom()
+	requiredPayload := w.BatchPayloadCap()
+	headroom := 0
+	for i := 0; i < w.batchSize; i++ {
+		pkt := pool.Get()
+		if pkt == nil {
+			releasePackets(packets[:i])
+			return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
+		}
+		if pkt.Capacity() < requiredPayload {
+			pkt.Release()
+			releasePackets(packets[:i])
+			return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
+		}
+		if i == 0 {
+			headroom = pkt.Offset
+			if headroom < requiredHeadroom {
+				pkt.Release()
+				releasePackets(packets[:i])
+				return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
+			}
+		} else if pkt.Offset != headroom {
+			pkt.Release()
+			releasePackets(packets[:i])
+			return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
+		}
+		packets[i] = pkt
+		w.readBuffers[i] = pkt.Buf
+	}
+
+	n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
+	if err != nil {
+		releasePackets(packets)
+		return nil, err
+	}
+	if n == 0 {
+		releasePackets(packets)
+		return nil, nil
+	}
+	for i := 0; i < n; i++ {
+		packets[i].Len = w.readLens[i]
+	}
+	for i := n; i < w.batchSize; i++ {
+		packets[i].Release()
+		packets[i] = nil
+	}
+	return packets[:n], nil
+}
+
+func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
+	if len(packets) == 0 {
+		return 0, nil
+	}
+	requiredHeadroom := w.BatchHeadroom()
+	offset := packets[0].Offset
+	if offset < requiredHeadroom {
+		releasePackets(packets)
+		return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
+	}
+	for _, pkt := range packets {
+		if pkt == nil {
+			continue
+		}
+		if pkt.Offset != offset {
+			releasePackets(packets)
+			return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
+		}
+		limit := pkt.Offset + pkt.Len
+		if limit > len(pkt.Buf) {
+			releasePackets(packets)
+			return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
+		}
+	}
+	w.writeMu.Lock()
+	defer w.writeMu.Unlock()
+
+	if len(w.writeBuffers) < len(packets) {
+		w.writeBuffers = make([][]byte, len(packets))
+	}
+	for i, pkt := range packets {
+		if pkt == nil {
+			w.writeBuffers[i] = nil
+			continue
+		}
+		limit := pkt.Offset + pkt.Len
+		w.writeBuffers[i] = pkt.Buf[:limit]
+	}
+	n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
+	releasePackets(packets)
+	return n, err
+}
+
+func (w *wireguardTunIO) BatchHeadroom() int {
+	return wgtun.VirtioNetHdrLen
+}
+
+func (w *wireguardTunIO) BatchPayloadCap() int {
+	return w.mtu
+}
+
+func (w *wireguardTunIO) BatchSize() int {
+	return w.batchSize
+}
+
 func (w *wireguardTunIO) Close() error {
 	return nil
 }
+
+func releasePackets(pkts []*Packet) {
+	for _, pkt := range pkts {
+		if pkt != nil {
+			pkt.Release()
+		}
+	}
+}
diff --git a/udp/conn.go b/udp/conn.go
index 895b0df..fdadba3 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -22,6 +22,18 @@ type Conn interface {
 	Close() error
 }
 
+// Datagram represents a UDP payload destined to a specific address.
+type Datagram struct {
+	Payload []byte
+	Addr    netip.AddrPort
+}
+
+// BatchConn can send multiple datagrams in one syscall.
+type BatchConn interface {
+	Conn
+	WriteBatch(pkts []Datagram) error
+}
+
 type NoopConn struct{}
 
 func (NoopConn) Rebind() error {
diff --git a/udp/wireguard_conn_linux.go b/udp/wireguard_conn_linux.go
index c6a1ede..c3f9e9a 100644
--- a/udp/wireguard_conn_linux.go
+++ b/udp/wireguard_conn_linux.go
@@ -20,8 +20,12 @@ type WGConn struct {
 	bind      *wgconn.StdNetBind
 	recvers   []wgconn.ReceiveFunc
 	batch     int
+	reqBatch  int
 	localIP   netip.Addr
 	localPort uint16
+	enableGSO bool
+	enableGRO bool
+	gsoMaxSeg int
 
 	closed    atomic.Bool
 	closeOnce sync.Once
@@ -34,7 +38,9 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool,
 	if err != nil {
 		return nil, err
 	}
-	if batch <= 0 || batch > bind.BatchSize() {
+	if batch <= 0 {
+		batch = bind.BatchSize()
+	} else if batch > bind.BatchSize() {
 		batch = bind.BatchSize()
 	}
 	return &WGConn{
@@ -42,6 +48,7 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool,
 		bind:      bind,
 		recvers:   recvers,
 		batch:     batch,
+		reqBatch:  batch,
 		localIP:   ip,
 		localPort: actualPort,
 	}, nil
@@ -118,6 +125,92 @@ func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return c.bind.Send([][]byte{b}, ep)
 }
 
+func (c *WGConn) WriteBatch(datagrams []Datagram) error {
+	if len(datagrams) == 0 {
+		return nil
+	}
+	if c.closed.Load() {
+		return net.ErrClosed
+	}
+	max := c.batch
+	if max <= 0 {
+		max = len(datagrams)
+		if max == 0 {
+			max = 1
+		}
+	}
+	bufs := make([][]byte, 0, max)
+	var (
+		current  netip.AddrPort
+		endpoint *wgconn.StdNetEndpoint
+		haveAddr bool
+	)
+	flush := func() error {
+		if len(bufs) == 0 || endpoint == nil {
+			bufs = bufs[:0]
+			return nil
+		}
+		err := c.bind.Send(bufs, endpoint)
+		bufs = bufs[:0]
+		return err
+	}
+
+	for _, d := range datagrams {
+		if len(d.Payload) == 0 || !d.Addr.IsValid() {
+			continue
+		}
+		if !haveAddr || d.Addr != current {
+			if err := flush(); err != nil {
+				return err
+			}
+			current = d.Addr
+			endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
+			haveAddr = true
+		}
+		bufs = append(bufs, d.Payload)
+		if len(bufs) >= max {
+			if err := flush(); err != nil {
+				return err
+			}
+		}
+	}
+	return flush()
+}
+
+func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
+	c.enableGSO = enableGSO
+	c.enableGRO = enableGRO
+	if maxSegments <= 0 {
+		maxSegments = 1
+	} else if maxSegments > wgconn.IdealBatchSize {
+		maxSegments = wgconn.IdealBatchSize
+	}
+	c.gsoMaxSeg = maxSegments
+
+	effectiveBatch := c.reqBatch
+	if enableGSO && c.bind != nil {
+		bindBatch := c.bind.BatchSize()
+		if effectiveBatch < bindBatch {
+			if c.l != nil {
+				c.l.WithFields(logrus.Fields{
+					"requested": c.reqBatch,
+					"effective": bindBatch,
+				}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
+			}
+			effectiveBatch = bindBatch
+		}
+	}
+	c.batch = effectiveBatch
+
+	if c.l != nil {
+		c.l.WithFields(logrus.Fields{
+			"enableGSO":      enableGSO,
+			"enableGRO":      enableGRO,
+			"gsoMaxSegments": maxSegments,
+		}).Debug("configured wireguard UDP offload")
+	}
+}
+
 func (c *WGConn) ReloadConfig(*config.C) {
 	// WireGuard bind currently does not expose runtime configuration knobs.
 }
diff --git a/wgstack/conn/errors_default.go b/wgstack/conn/errors_default.go
new file mode 100644
index 0000000..f1e5b90
--- /dev/null
+++ b/wgstack/conn/errors_default.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(err error) bool {
+	return false
+}
diff --git a/wgstack/conn/errors_linux.go b/wgstack/conn/errors_linux.go
new file mode 100644
index 0000000..8e61000
--- /dev/null
+++ b/wgstack/conn/errors_linux.go
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"errors"
+	"os"
+
+	"golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+	var serr *os.SyscallError
+	if errors.As(err, &serr) {
+		// EIO is returned by udp_send_skb() if the device driver does not have
+		// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+		// See:
+		// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+		// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+		return serr.Err == unix.EIO
+	}
+	return false
+}
diff --git a/wgstack/conn/features_default.go b/wgstack/conn/features_default.go
new file mode 100644
index 0000000..d53ff5f
--- /dev/null
+++ b/wgstack/conn/features_default.go
@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+	return
+}
diff --git a/wgstack/conn/features_linux.go b/wgstack/conn/features_linux.go
new file mode 100644
index 0000000..8959d93
--- /dev/null
+++ b/wgstack/conn/features_linux.go
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"net"
+
+	"golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+	rc, err := conn.SyscallConn()
+	if err != nil {
+		return
+	}
+	err = rc.Control(func(fd uintptr) {
+		_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+		txOffload = errSyscall == nil
+		opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+		rxOffload = errSyscall == nil && opt == 1
+	})
+	if err != nil {
+		return false, false
+	}
+	return txOffload, rxOffload
+}
diff --git a/wgstack/conn/gso_default.go b/wgstack/conn/gso_default.go
new file mode 100644
index 0000000..57780db
--- /dev/null
+++ b/wgstack/conn/gso_default.go
@@ -0,0 +1,21 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+	return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const gsoControlSize = 0
diff --git a/wgstack/conn/gso_linux.go b/wgstack/conn/gso_linux.go
new file mode 100644
index 0000000..8596b29
--- /dev/null
+++ b/wgstack/conn/gso_linux.go
@@ -0,0 +1,65 @@
+//go:build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"fmt"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+const (
+	sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+	var (
+		hdr  unix.Cmsghdr
+		data []byte
+		rem  = control
+		err  error
+	)
+
+	for len(rem) > unix.SizeofCmsghdr {
+		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+		if err != nil {
+			return 0, fmt.Errorf("error parsing socket control message: %w", err)
+		}
+		if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+			var gso uint16
+			copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+			return int(gso), nil
+		}
+	}
+	return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+	existingLen := len(*control)
+	avail := cap(*control) - existingLen
+	space := unix.CmsgSpace(sizeOfGSOData)
+	if avail < space {
+		return
+	}
+	*control = (*control)[:cap(*control)]
+	gsoControl := (*control)[existingLen:]
+	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+	hdr.Level = unix.SOL_UDP
+	hdr.Type = unix.UDP_SEGMENT
+	hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+	copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+	*control = (*control)[:existingLen+space]
+}
+
+// gsoControlSize returns the recommended buffer size for pooling UDP
+// offloading control data.
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
diff --git a/wgstack/conn/sticky_default.go b/wgstack/conn/sticky_default.go
new file mode 100644
index 0000000..0b21386
--- /dev/null
+++ b/wgstack/conn/sticky_default.go
@@ -0,0 +1,42 @@
+//go:build !linux || android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+	return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+	return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+	return ""
+}
+
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
+// and source ifindex found in ep.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+const stickyControlSize = 0
+
+const StdNetSupportsStickySockets = false
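
Reviewer note, not part of the patch: the batch pipeline above hinges on the PacketPool/Packet lifecycle introduced in overlay/device.go, so here is a minimal, self-contained sketch of that lifecycle. The sizes passed to NewPacketPool (10 bytes of headroom, 1500 of payload) are illustrative only; the real pipeline derives both from the device via BatchHeadroom() and BatchPayloadCap() in batchPipelines.init.

package main

import (
	"fmt"

	"github.com/slackhq/nebula/overlay"
)

func main() {
	// Illustrative sizes; the pipeline sizes the pool from the device.
	pool := overlay.NewPacketPool(10, 1500)

	pkt := pool.Get()
	// A fresh packet has Len == 0, so Payload() is empty. Fill it by
	// appending through the spare capacity that sits after the reserved
	// headroom, then record the final length.
	payload := append(pkt.Payload()[:0], []byte("example payload")...)
	pkt.Len = len(payload)

	fmt.Println(pkt.Len, pkt.Capacity()) // 15 1500

	// Buf[:Offset] stays reserved so a batch writer can prepend a
	// virtio-net header without copying. Whoever consumes the packet
	// releases it; neither pkt nor payload may be touched after Release.
	pkt.Release()
}

This ownership convention (the consumer of a packet calls Release, exactly once) is what the BatchWriter contract and the queue workers above rely on.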