This commit is contained in:
JackDoan
2026-04-20 11:57:48 -05:00
parent f34e8fe0e6
commit f4907b6634
3 changed files with 27 additions and 33 deletions

View File

@@ -47,7 +47,7 @@ type tunFile struct {
segOff int // cursor into segBuf for the current ReadBatch drain segOff int // cursor into segBuf for the current ReadBatch drain
pending [][]byte // segments waiting to be drained by Read pending [][]byte // segments waiting to be drained by Read
pendingIdx int pendingIdx int
writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to zeroVnetHdr writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr
// rejectIovs is a second preallocated iovec scratch used exclusively by // rejectIovs is a second preallocated iovec scratch used exclusively by
// WriteReject (reject + self-forward from the inside path). It mirrors // WriteReject (reject + self-forward from the inside path). It mirrors
// writeIovs but lets listenIn goroutines emit reject packets without // writeIovs but lets listenIn goroutines emit reject packets without
@@ -55,7 +55,7 @@ type tunFile struct {
rejectIovs [2]unix.Iovec rejectIovs [2]unix.Iovec
// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted // gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
// by WriteGSO. Separate from zeroVnetHdr so a concurrent non-GSO Write on // by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on
// another queue never observes a half-written header. // another queue never observes a half-written header.
gsoHdrBuf [virtioNetHdrLen]byte gsoHdrBuf [virtioNetHdrLen]byte
// gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the // gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the
@@ -69,10 +69,15 @@ type tunFile struct {
// any reallocations. // any reallocations.
const gsoInitialPayIovs = 66 const gsoInitialPayIovs = 66
// zeroVnetHdr is the 10-byte virtio_net_hdr we prepend to every TUN write when // validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// IFF_VNET_HDR is active. All-zero signals "no GSO, no checksum offload"; the // write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
// kernel accepts the packet as-is. // CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
var zeroVnetHdr [virtioNetHdrLen]byte // verification. All packets that reach the plain Write / WriteReject paths
// already carry a valid L4 checksum (either supplied by a remote peer whose
// ciphertext we AEAD-authenticated, or produced by finishChecksum during TSO
// segmentation, or built locally by CreateRejectPacket), so trusting them is
// safe.
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun // newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
func (r *tunFile) newFriend(fd int) (*tunFile, error) { func (r *tunFile) newFriend(fd int) (*tunFile, error) {
@@ -95,9 +100,9 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
} }
if r.vnetHdr { if r.vnetHdr {
out.segBuf = make([]byte, tunSegBufCap) out.segBuf = make([]byte, tunSegBufCap)
out.writeIovs[0].Base = &zeroVnetHdr[0] out.writeIovs[0].Base = &validVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen) out.writeIovs[0].SetLen(virtioNetHdrLen)
out.rejectIovs[0].Base = &zeroVnetHdr[0] out.rejectIovs[0].Base = &validVnetHdr[0]
out.rejectIovs[0].SetLen(virtioNetHdrLen) out.rejectIovs[0].SetLen(virtioNetHdrLen)
out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs) out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs)
out.gsoIovs[0].Base = &out.gsoHdrBuf[0] out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
@@ -133,9 +138,9 @@ func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
} }
if vnetHdr { if vnetHdr {
out.segBuf = make([]byte, tunSegBufCap) out.segBuf = make([]byte, tunSegBufCap)
out.writeIovs[0].Base = &zeroVnetHdr[0] out.writeIovs[0].Base = &validVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen) out.writeIovs[0].SetLen(virtioNetHdrLen)
out.rejectIovs[0].Base = &zeroVnetHdr[0] out.rejectIovs[0].Base = &validVnetHdr[0]
out.rejectIovs[0].SetLen(virtioNetHdrLen) out.rejectIovs[0].SetLen(virtioNetHdrLen)
out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs) out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs)
out.gsoIovs[0].Base = &out.gsoHdrBuf[0] out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
@@ -340,7 +345,7 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
return 0, nil return 0, nil
} }
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired // Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
// to zeroVnetHdr during tunFile construction so we don't rebuild it here. // to validVnetHdr during tunFile construction so we don't rebuild it here.
iovs[1].Base = &buf[0] iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf)) iovs[1].SetLen(len(buf))
iovPtr := uintptr(unsafe.Pointer(&iovs[0])) iovPtr := uintptr(unsafe.Pointer(&iovs[0]))

View File

@@ -313,7 +313,7 @@ func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
t.Cleanup(func() { _ = unix.Close(fd) }) t.Cleanup(func() { _ = unix.Close(fd) })
tf := &tunFile{fd: fd, vnetHdr: true} tf := &tunFile{fd: fd, vnetHdr: true}
tf.writeIovs[0].Base = &zeroVnetHdr[0] tf.writeIovs[0].Base = &validVnetHdr[0]
tf.writeIovs[0].SetLen(virtioNetHdrLen) tf.writeIovs[0].SetLen(virtioNetHdrLen)
payload := make([]byte, 1400) payload := make([]byte, 1400)

View File

@@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"io" "io"
@@ -392,55 +393,43 @@ func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
if isV6 { if isV6 {
// IPv6: bytes [0:4] = version/TC/flow-label, [6:8] = next_hdr/hop, // IPv6: bytes [0:4] = version/TC/flow-label, [6:8] = next_hdr/hop,
// [8:40] = src+dst. Skip [4:6] payload length. // [8:40] = src+dst. Skip [4:6] payload length.
if !bytesEq(a[0:4], b[0:4]) { if !bytes.Equal(a[0:4], b[0:4]) {
return false return false
} }
if !bytesEq(a[6:40], b[6:40]) { if !bytes.Equal(a[6:40], b[6:40]) {
return false return false
} }
} else { } else {
// IPv4: [0:2] version/IHL/TOS, [6:10] flags/fragoff/TTL/proto, // IPv4: [0:2] version/IHL/TOS, [6:10] flags/fragoff/TTL/proto,
// [12:20] src+dst. Skip [2:4] total len, [4:6] id, [10:12] csum. // [12:20] src+dst. Skip [2:4] total len, [4:6] id, [10:12] csum.
if !bytesEq(a[0:2], b[0:2]) { if !bytes.Equal(a[0:2], b[0:2]) {
return false return false
} }
if !bytesEq(a[6:10], b[6:10]) { if !bytes.Equal(a[6:10], b[6:10]) {
return false return false
} }
if !bytesEq(a[12:20], b[12:20]) { if !bytes.Equal(a[12:20], b[12:20]) {
return false return false
} }
} }
// TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window, // TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window,
// [18:tcpHdrLen] options (incl. urgent). // [18:tcpHdrLen] options (incl. urgent).
tcp := ipHdrLen tcp := ipHdrLen
if !bytesEq(a[tcp:tcp+4], b[tcp:tcp+4]) { if !bytes.Equal(a[tcp:tcp+4], b[tcp:tcp+4]) {
return false return false
} }
if !bytesEq(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) { if !bytes.Equal(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) {
return false return false
} }
if !bytesEq(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) { if !bytes.Equal(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) {
return false return false
} }
if !bytesEq(a[tcp+18:], b[tcp+18:]) { if !bytes.Equal(a[tcp+18:], b[tcp+18:]) {
return false return false
} }
return true return true
} }
func bytesEq(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must // ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must
// already have its checksum field zeroed) and returns the folded/inverted // already have its checksum field zeroed) and returns the folded/inverted
// 16-bit value to store. // 16-bit value to store.