// Patch summary (unified diff; the text below is unchanged and has its newlines collapsed):
//
// inside.go
//   - sendNoMetrics: when p is non-empty and slicesOverlap(out, p) is true, p is copied into a
//     freshly allocated slice before ci.eKey.EncryptDanger(out, out, p, c, nb) is called.
//   - slicesOverlap (new): returns false if either slice is empty; otherwise compares the address
//     ranges [&a[0], &a[0]+len(a)) and [&b[0], &b[0]+len(b)) via unsafe.Pointer/uintptr.
//
// interface.go
//   - listenInBatchLocked: if isVirtioHeadroomError(err) is true, logs a warning, calls
//     f.listenInLegacyLocked(raw, i) and returns, instead of reaching the Error log + os.Exit(2).
//   - runTunWriteQueue: reads writer.BatchHeadroom() once; the flush path filters pending through
//     ensurePacketHeadroom into valid (which reuses pending's backing array), calls
//     writer.WriteBatch(valid), and on error logs a warning and passes every packet in valid to
//     f.writePacketToTun. The enqueue path only appends a packet when ensurePacketHeadroom
//     returns true.
//   - clonePacketWithHeadroom (new): returns nil for a nil packet; returns the packet unchanged
//     when its payload is empty and required <= 0; otherwise copies the payload into a packet from
//     f.batches.Pool() when one is available and large enough, else into a new buffer of
//     required+len(payload) bytes with Offset = required. The source packet is released after copying.
//   - ensurePacketHeadroom (new): false for a nil packet; true when required <= 0 or
//     p.Offset >= required; otherwise replaces *pkt with clonePacketWithHeadroom's result.
//   - isVirtioHeadroomError (new): true when err.Error() contains "headroom" or "virtio".
//
// overlay/wireguard_tun_linux.go
//   - wireguardTunIO.WriteBatch: returns (n, err) right after a failed w.dev.Write without calling
//     releasePackets; packets are released only on success.
//
// NOTE(review): slicesOverlap only covers len(a) bytes. If EncryptDanger appends ciphertext past
//   len(out) the way cipher.AEAD.Seal appends to dst, the region actually written is
//   out[len(out):cap(out)], which this check does not examine - confirm against EncryptDanger.
// NOTE(review): the slicesOverlap comment says Seal needs disjoint regions; the Go docs phrase the
//   rule as "overlap exactly or not at all" - confirm whether exact in-place use occurs here.
// NOTE(review): the pool branch of clonePacketWithHeadroom returns the pooled packet without
//   comparing its Offset to required, and ensurePacketHeadroom then reports success. Presumably
//   pooled packets already reserve enough headroom - verify.
// NOTE(review): clonePacketWithHeadroom returns nil only for a nil input, and ensurePacketHeadroom
//   rejects nil before calling it, so the "dropping packet lacking tun headroom" branch looks
//   unreachable as written.
// NOTE(review): the flush path discards WriteBatch's count and re-sends every packet in valid on
//   error; if a batch can be partially written this may duplicate packets - confirm w.dev.Write
//   semantics.
// NOTE(review): the fallback relies on WriteBatch NOT releasing packets on error (true for
//   wireguardTunIO after this patch). The visible tail of writePacketToTun calls pkt.Release(), so
//   any other WriteBatch implementation that still releases on error would double-release - verify.
// NOTE(review): isVirtioHeadroomError matches message substrings; any error mentioning "virtio"
//   triggers the legacy-reader fallback. A sentinel or typed error checked with errors.Is/As would
//   be less brittle, if the producer of the error can expose one.
diff --git a/inside.go b/inside.go index 22b63d4..9e3672c 100644 --- a/inside.go +++ b/inside.go @@ -2,6 +2,7 @@ package nebula import ( "net/netip" + "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" @@ -384,6 +385,11 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } var err error + if len(p) > 0 && slicesOverlap(out, p) { + tmp := make([]byte, len(p)) + copy(tmp, p) + p = tmp + } out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) if noiseutil.EncryptLockNeeded { ci.writeLock.Unlock() @@ -447,3 +453,17 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType break } } + +// slicesOverlap reports whether the two byte slices share any portion of memory. +// cipher.AEAD.Seal requires plaintext and dst to live in disjoint regions. +func slicesOverlap(a, b []byte) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + + aStart := uintptr(unsafe.Pointer(&a[0])) + aEnd := aStart + uintptr(len(a)) + bStart := uintptr(unsafe.Pointer(&b[0])) + bEnd := bStart + uintptr(len(b)) + return aStart < bEnd && bStart < aEnd +} diff --git a/interface.go b/interface.go index 05ad918..b4abade 100644 --- a/interface.go +++ b/interface.go @@ -8,6 +8,7 @@ import ( "net/netip" "os" "runtime" + "strings" "sync/atomic" "time" @@ -399,6 +400,12 @@ func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.B return } + if isVirtioHeadroomError(err) { + f.l.WithError(err).Warn("Batch reader fell back due to tun headroom issue") + f.listenInLegacyLocked(raw, i) + return + } + f.l.WithError(err).Error("Error while reading outbound packet batch") os.Exit(2) } @@ -549,6 +556,7 @@ func (f *Interface) runTunWriteQueue(i int) { if writer == nil { return } + requiredHeadroom := writer.BatchHeadroom() batchCap := f.batches.batchSizeHint() if batchCap <= 0 { @@ -563,15 +571,27 @@ func (f *Interface) runTunWriteQueue(i int) { if len(pending) == 0 { return } - if _, err := 
writer.WriteBatch(pending); err != nil { - f.l.WithError(err). - WithField("queue", i). - WithField("reason", reason). - Warn("Failed to write tun batch") - } + valid := pending[:0] for idx := range pending { + if !f.ensurePacketHeadroom(&pending[idx], requiredHeadroom, i, reason) { + pending[idx] = nil + continue + } if pending[idx] != nil { - pending[idx].Release() + valid = append(valid, pending[idx]) + } + } + if len(valid) > 0 { + if _, err := writer.WriteBatch(valid); err != nil { + f.l.WithError(err). + WithField("queue", i). + WithField("reason", reason). + Warn("Failed to write tun batch") + for _, pkt := range valid { + if pkt != nil { + f.writePacketToTun(i, pkt) + } + } } } pending = pending[:0] @@ -605,7 +625,9 @@ func (f *Interface) runTunWriteQueue(i int) { if pkt == nil { continue } - pending = append(pending, pkt) + if f.ensurePacketHeadroom(&pkt, requiredHeadroom, i, "queue") { + pending = append(pending, pkt) + } if len(pending) >= cap(pending) { flush("cap", false) continue @@ -811,6 +833,40 @@ func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) { pkt.Release() } +func (f *Interface) clonePacketWithHeadroom(pkt *overlay.Packet, required int) *overlay.Packet { + if pkt == nil { + return nil + } + payload := pkt.Payload()[:pkt.Len] + if len(payload) == 0 && required <= 0 { + return pkt + } + + pool := f.batches.Pool() + if pool != nil { + if clone := pool.Get(); clone != nil { + if len(clone.Payload()) >= len(payload) { + clone.Len = copy(clone.Payload(), payload) + pkt.Release() + return clone + } + clone.Release() + } + } + + if required < 0 { + required = 0 + } + buf := make([]byte, required+len(payload)) + n := copy(buf[required:], payload) + pkt.Release() + return &overlay.Packet{ + Buf: buf, + Offset: required, + Len: n, + } +} + func (f *Interface) observeUDPQueueLen(i int) { if f.batchUDPQueueGauge == nil { return @@ -832,6 +888,34 @@ func (f *Interface) currentBatchFlushInterval() time.Duration { return 0 } +func (f 
*Interface) ensurePacketHeadroom(pkt **overlay.Packet, required int, queue int, reason string) bool { + p := *pkt + if p == nil { + return false + } + if required <= 0 || p.Offset >= required { + return true + } + clone := f.clonePacketWithHeadroom(p, required) + if clone == nil { + f.l.WithFields(logrus.Fields{ + "queue": queue, + "reason": reason, + }).Warn("dropping packet lacking tun headroom") + return false + } + *pkt = clone + return true +} + +func isVirtioHeadroomError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "headroom") || strings.Contains(msg, "virtio") +} + func (f *Interface) effectiveGSOMaxSegments() int { max := f.gsoMaxSegments if max <= 0 { diff --git a/overlay/wireguard_tun_linux.go b/overlay/wireguard_tun_linux.go index c8a36fc..1579c4b 100644 --- a/overlay/wireguard_tun_linux.go +++ b/overlay/wireguard_tun_linux.go @@ -188,8 +188,11 @@ func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) { w.writeBuffers[i] = pkt.Buf[:limit] } n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset) + if err != nil { + return n, err + } releasePackets(packets) - return n, err + return n, nil } func (w *wireguardTunIO) BatchHeadroom() int {