Files
nebula/overlay/tio/tun_linux_offload_test.go
2026-05-11 11:32:57 -05:00

795 lines
24 KiB
Go

//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"encoding/binary"
"os"
"testing"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// testSegScratchSize is a generous segmentation scratch sized to fit any
// of the synthetic TSO/USO superpackets these tests generate (one
// worst-case 64 KiB superpacket plus replicated per-segment headers).
// NOTE(review): segmentForTest never references its scratch argument, so
// this size currently only matters if the production segmentation path
// grows a scratch requirement — confirm before shrinking.
const testSegScratchSize = 192 * 1024
// verifyChecksum reports whether the one's-complement sum across b,
// seeded with a folded pseudo-header sum, folds to all-ones — i.e. the
// checksum field embedded in b is valid.
func verifyChecksum(b []byte, pseudo uint16) bool {
	sum := checksum.Checksum(b, pseudo)
	return sum == 0xffff
}
// segmentForTest is the test-only counterpart to the production
// SegmentSuperpacket path. GSO_NONE packets are handled inline (with an
// optional FinishChecksum pass); GSO superpackets are dispatched through
// SegmentSuperpacket. Every yielded segment is drained into a
// freshly-copied [][]byte slot so callers can iterate after the call
// returns. Tests pre-set hdr.HdrLen correctly, so correctHdrLen is not
// invoked here. The scratch argument is accepted for call-site symmetry
// but is not referenced by this helper.
func segmentForTest(pkt []byte, hdr virtio.Hdr, out *[][]byte, scratch []byte) error {
	if hdr.GSOType != unix.VIRTIO_NET_HDR_GSO_NONE {
		// Superpacket path: translate the virtio header into GSO metadata
		// and let the production segmenter do the work.
		proto, err := protoFromGSOType(hdr.GSOType)
		if err != nil {
			return err
		}
		info := GSOInfo{
			Size:      hdr.GSOSize,
			HdrLen:    hdr.HdrLen,
			CsumStart: hdr.CsumStart,
			Proto:     proto,
		}
		return SegmentSuperpacket(Packet{Bytes: pkt, GSO: info}, func(seg []byte) error {
			*out = append(*out, append([]byte(nil), seg...))
			return nil
		})
	}
	// GSO_NONE: single packet, optionally finishing a deferred checksum.
	cp := append([]byte(nil), pkt...)
	if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
		if err := virtio.FinishChecksum(cp, hdr); err != nil {
			return err
		}
	}
	*out = append(*out, cp)
	return nil
}
// pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 4 bytes each.
func pseudoHeaderIPv4(src, dst []byte, proto byte, l4Len int) uint16 {
	sum := uint32(checksum.Checksum(src, 0))
	sum += uint32(checksum.Checksum(dst, 0))
	sum += uint32(proto)
	sum += uint32(l4Len)
	// Fold twice: a single fold can still leave a carry in the high half.
	for range [2]struct{}{} {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return uint16(sum)
}
// pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 16 bytes each.
func pseudoHeaderIPv6(src, dst []byte, proto byte, l4Len int) uint16 {
	sum := uint32(checksum.Checksum(src, 0))
	sum += uint32(checksum.Checksum(dst, 0))
	// The IPv6 pseudo-header carries a 32-bit upper-layer length: add
	// both halves separately so large lengths contribute correctly.
	sum += uint32(l4Len >> 16)
	sum += uint32(l4Len & 0xffff)
	sum += uint32(proto)
	// Fold twice: a single fold can still leave a carry in the high half.
	for range [2]struct{}{} {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return uint16(sum)
}
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
// `payLen` bytes split at `mss`, along with the matching virtio header.
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtio.Hdr) {
	t.Helper()
	const (
		ipLen  = 20
		tcpLen = 20
	)
	b := make([]byte, ipLen+tcpLen+payLen)
	// IPv4 header.
	b[0] = 0x45 // version 4, IHL 5
	// Total length is meaningless for TSO but set it anyway.
	binary.BigEndian.PutUint16(b[2:4], uint16(len(b)))
	binary.BigEndian.PutUint16(b[4:6], 0x4242) // original ID
	b[8] = 64 // TTL
	b[9] = unix.IPPROTO_TCP
	copy(b[12:16], []byte{10, 0, 0, 1}) // src
	copy(b[16:20], []byte{10, 0, 0, 2}) // dst
	// TCP header.
	binary.BigEndian.PutUint16(b[20:22], 12345) // sport
	binary.BigEndian.PutUint16(b[22:24], 80)    // dport
	binary.BigEndian.PutUint32(b[24:28], 10000) // seq
	binary.BigEndian.PutUint32(b[28:32], 20000) // ack
	b[32] = 0x50 // data offset 5 words
	b[33] = 0x18 // ACK | PSH
	binary.BigEndian.PutUint16(b[34:36], 65535) // window
	// Deterministic payload pattern.
	for i := range b[ipLen+tcpLen:] {
		b[ipLen+tcpLen+i] = byte(i & 0xff)
	}
	return b, virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen),
		CsumOffset: 16,
	}
}
// TestSegmentTCPv4 checks the basic even-split TSO case: per-segment
// lengths, incrementing IP IDs, advancing sequence numbers, deferred PSH,
// and both checksums.
func TestSegmentTCPv4(t *testing.T) {
	const (
		mss    = 100
		numSeg = 3
	)
	pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(segs))
	}
	for idx, s := range segs {
		// Each segment is the full 40-byte header stack plus one MSS.
		if len(s) != 40+mss {
			t.Errorf("seg %d: unexpected len %d", idx, len(s))
		}
		if got := binary.BigEndian.Uint16(s[2:4]); got != uint16(40+mss) {
			t.Errorf("seg %d: total_len=%d want %d", idx, got, 40+mss)
		}
		// IPv4 ID increments monotonically from the seed's ID.
		if got := binary.BigEndian.Uint16(s[4:6]); got != 0x4242+uint16(idx) {
			t.Errorf("seg %d: ip id=%#x want %#x", idx, got, 0x4242+uint16(idx))
		}
		if got, want := binary.BigEndian.Uint32(s[24:28]), uint32(10000+idx*mss); got != want {
			t.Errorf("seg %d: seq=%d want %d", idx, got, want)
		}
		// PSH is deferred to the final segment; ACK rides on every one.
		want := byte(0x10)
		if idx == numSeg-1 {
			want = 0x18
		}
		if s[33] != want {
			t.Errorf("seg %d: flags=%#x want %#x", idx, s[33], want)
		}
		// IPv4 header checksum must verify against itself.
		if !verifyChecksum(s[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", idx)
		}
		// TCP checksum must verify against the pseudo-header.
		psum := pseudoHeaderIPv4(s[12:16], s[16:20], unix.IPPROTO_TCP, 20+mss)
		if !verifyChecksum(s[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", idx)
		}
	}
}
// TestSegmentTCPv4OddTail checks that a payload not evenly divisible by
// the MSS produces a short final segment with valid checksums.
func TestSegmentTCPv4OddTail(t *testing.T) {
	// Payload of 250 bytes with MSS 100 → segments of 100, 100, 50.
	pkt, hdr := buildTSOv4(t, 250, 100)
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != 3 {
		t.Fatalf("want 3 segments, got %d", len(segs))
	}
	for idx, want := range []int{100, 100, 50} {
		s := segs[idx]
		if got := len(s) - 40; got != want {
			t.Errorf("seg %d: pay len %d want %d", idx, got, want)
		}
		if !verifyChecksum(s[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", idx)
		}
		psum := pseudoHeaderIPv4(s[12:16], s[16:20], unix.IPPROTO_TCP, 20+want)
		if !verifyChecksum(s[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", idx)
		}
	}
}
// TestSegmentTCPv6 checks IPv6 TSO: per-segment payload_length rewrite,
// advancing sequence numbers, FIN/PSH deferral to the last segment, and
// TCP checksums over the IPv6 pseudo-header.
func TestSegmentTCPv6(t *testing.T) {
	const (
		ipLen  = 40
		tcpLen = 20
		mss    = 120
		numSeg = 2
		payLen = mss * numSeg
	)
	pkt := make([]byte, ipLen+tcpLen+payLen)
	// IPv6 header.
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
	pkt[6] = unix.IPPROTO_TCP
	pkt[7] = 64
	// src/dst fe80::1 / fe80::2
	pkt[8], pkt[9], pkt[23] = 0xfe, 0x80, 1
	pkt[24], pkt[25], pkt[39] = 0xfe, 0x80, 2
	// TCP header.
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 80)
	binary.BigEndian.PutUint32(pkt[44:48], 7)
	binary.BigEndian.PutUint32(pkt[48:52], 99)
	pkt[52] = 0x50
	pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
	binary.BigEndian.PutUint16(pkt[54:56], 65535)
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i)
	}
	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen),
		CsumOffset: 16,
	}
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, len(segs))
	}
	for idx, s := range segs {
		if len(s) != ipLen+tcpLen+mss {
			t.Errorf("seg %d: len %d want %d", idx, len(s), ipLen+tcpLen+mss)
		}
		if pl := binary.BigEndian.Uint16(s[4:6]); pl != uint16(tcpLen+mss) {
			t.Errorf("seg %d: payload_length=%d want %d", idx, pl, tcpLen+mss)
		}
		if seq := binary.BigEndian.Uint32(s[44:48]); seq != uint32(7+idx*mss) {
			t.Errorf("seg %d: seq=%d want %d", idx, seq, 7+idx*mss)
		}
		// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
		// cleared on all but the last; ACK(0x10) always preserved.
		want := byte(0x10)
		if idx == numSeg-1 {
			want = 0x19
		}
		if s[53] != want {
			t.Errorf("seg %d: flags=%#x want %#x", idx, s[53], want)
		}
		psum := pseudoHeaderIPv6(s[8:24], s[24:40], unix.IPPROTO_TCP, tcpLen+mss)
		if !verifyChecksum(s[ipLen:], psum) {
			t.Errorf("seg %d: bad TCP checksum", idx)
		}
	}
}
// TestSegmentGSONonePassesThrough checks that a GSO_NONE packet without
// NEEDS_CSUM is passed through as a single untouched segment.
func TestSegmentGSONonePassesThrough(t *testing.T) {
	pkt, hdr := buildTSOv4(t, 100, 100)
	hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
	hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	switch {
	case len(segs) != 1:
		t.Fatalf("want 1 segment, got %d", len(segs))
	case len(segs[0]) != len(pkt):
		t.Fatalf("unexpected length: %d vs %d", len(segs[0]), len(pkt))
	}
}
// TestSegmentRejectsLegacyUDPGSO ensures the legacy GSO_UDP (UFO) marker is
// still rejected; only modern GSO_UDP_L4 (USO) is supported.
func TestSegmentRejectsLegacyUDPGSO(t *testing.T) {
hdr := virtio.Hdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
var out [][]byte
if err := segmentForTest(nil, hdr, &out, nil); err == nil {
t.Fatalf("expected rejection for legacy UDP GSO")
}
}
// buildUSOv4 builds a synthetic IPv4/UDP USO superpacket with payload of
// payLen bytes, segmented at gsoSize, along with the matching virtio header.
func buildUSOv4(t *testing.T, payLen, gsoSize int) ([]byte, virtio.Hdr) {
	t.Helper()
	const (
		ipLen  = 20
		udpLen = 8
	)
	b := make([]byte, ipLen+udpLen+payLen)
	// IPv4 header.
	b[0] = 0x45 // version 4, IHL 5
	binary.BigEndian.PutUint16(b[2:4], uint16(len(b)))
	binary.BigEndian.PutUint16(b[4:6], 0x4242)
	b[8] = 64
	b[9] = unix.IPPROTO_UDP
	copy(b[12:16], []byte{10, 0, 0, 1})
	copy(b[16:20], []byte{10, 0, 0, 2})
	// UDP header (length + checksum filled in per segment by segmentUDPYield)
	binary.BigEndian.PutUint16(b[20:22], 12345) // sport
	binary.BigEndian.PutUint16(b[22:24], 53)    // dport
	// Deterministic payload pattern.
	for i := range b[ipLen+udpLen:] {
		b[ipLen+udpLen+i] = byte(i & 0xff)
	}
	return b, virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
		HdrLen:     uint16(ipLen + udpLen),
		GSOSize:    uint16(gsoSize),
		CsumStart:  uint16(ipLen),
		CsumOffset: 6,
	}
}
// TestSegmentUDPv4 checks the even-split USO case: per-segment lengths,
// the constant IPv4 ID, per-segment UDP length, and both checksums.
func TestSegmentUDPv4(t *testing.T) {
	const (
		gso    = 100
		numSeg = 3
	)
	pkt, hdr := buildUSOv4(t, gso*numSeg, gso)
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(segs))
	}
	for idx, s := range segs {
		if len(s) != 28+gso {
			t.Errorf("seg %d: len %d want %d", idx, len(s), 28+gso)
		}
		if got := binary.BigEndian.Uint16(s[2:4]); got != uint16(28+gso) {
			t.Errorf("seg %d: total_len=%d want %d", idx, got, 28+gso)
		}
		// kernel UDP-GSO does NOT bump the IPv4 ID across segments; every
		// segment carries the same ID as the seed.
		if got := binary.BigEndian.Uint16(s[4:6]); got != 0x4242 {
			t.Errorf("seg %d: ip id=%#x want %#x", idx, got, 0x4242)
		}
		if got := binary.BigEndian.Uint16(s[24:26]); got != uint16(8+gso) {
			t.Errorf("seg %d: udp len=%d want %d", idx, got, 8+gso)
		}
		if !verifyChecksum(s[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", idx)
		}
		psum := pseudoHeaderIPv4(s[12:16], s[16:20], unix.IPPROTO_UDP, 8+gso)
		if !verifyChecksum(s[20:], psum) {
			t.Errorf("seg %d: bad UDP checksum", idx)
		}
	}
}
// TestSegmentUDPv4OddTail checks that a payload not evenly divisible by
// the GSO size produces a short final datagram with valid checksums.
func TestSegmentUDPv4OddTail(t *testing.T) {
	// 250 bytes payload, gsoSize=100 → segments of 100, 100, 50.
	pkt, hdr := buildUSOv4(t, 250, 100)
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != 3 {
		t.Fatalf("want 3 segments, got %d", len(segs))
	}
	for idx, want := range []int{100, 100, 50} {
		s := segs[idx]
		if got := len(s) - 28; got != want {
			t.Errorf("seg %d: pay len %d want %d", idx, got, want)
		}
		if got := binary.BigEndian.Uint16(s[24:26]); got != uint16(8+want) {
			t.Errorf("seg %d: udp len=%d want %d", idx, got, 8+want)
		}
		if !verifyChecksum(s[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", idx)
		}
		psum := pseudoHeaderIPv4(s[12:16], s[16:20], unix.IPPROTO_UDP, 8+want)
		if !verifyChecksum(s[20:], psum) {
			t.Errorf("seg %d: bad UDP checksum", idx)
		}
	}
}
// TestSegmentUDPv6 checks IPv6 USO: per-segment payload_length and UDP
// length rewrites plus the UDP checksum over the IPv6 pseudo-header.
func TestSegmentUDPv6(t *testing.T) {
	const (
		ipLen  = 40
		udpLen = 8
		gso    = 120
		numSeg = 2
		payLen = gso * numSeg
	)
	pkt := make([]byte, ipLen+udpLen+payLen)
	// IPv6 header: fe80::1 → fe80::2, next header UDP.
	pkt[0] = 0x60
	binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen+payLen))
	pkt[6] = unix.IPPROTO_UDP
	pkt[7] = 64
	pkt[8], pkt[9], pkt[23] = 0xfe, 0x80, 1
	pkt[24], pkt[25], pkt[39] = 0xfe, 0x80, 2
	// UDP header: ports only; per-segment length/checksum are the
	// segmenter's job.
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 53)
	for i := 0; i < payLen; i++ {
		pkt[ipLen+udpLen+i] = byte(i)
	}
	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
		HdrLen:     uint16(ipLen + udpLen),
		GSOSize:    uint16(gso),
		CsumStart:  uint16(ipLen),
		CsumOffset: 6,
	}
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, len(segs))
	}
	for idx, s := range segs {
		if len(s) != ipLen+udpLen+gso {
			t.Errorf("seg %d: len %d want %d", idx, len(s), ipLen+udpLen+gso)
		}
		if pl := binary.BigEndian.Uint16(s[4:6]); pl != uint16(udpLen+gso) {
			t.Errorf("seg %d: payload_length=%d want %d", idx, pl, udpLen+gso)
		}
		if ul := binary.BigEndian.Uint16(s[ipLen+4 : ipLen+6]); ul != uint16(udpLen+gso) {
			t.Errorf("seg %d: udp len=%d want %d", idx, ul, udpLen+gso)
		}
		psum := pseudoHeaderIPv6(s[8:24], s[24:40], unix.IPPROTO_UDP, udpLen+gso)
		if !verifyChecksum(s[ipLen:], psum) {
			t.Errorf("seg %d: bad UDP checksum", idx)
		}
	}
}
// TestSegmentUDPCEPropagates confirms IP-level CE marks on the seed appear on
// every segment. UDP has no transport-level CWR/ECE: the IP TOS/TC byte is
// copied verbatim into every segment by the segment-prefix copy.
func TestSegmentUDPCEPropagates(t *testing.T) {
	pkt, hdr := buildUSOv4(t, 200, 100)
	pkt[1] = 0x03 // CE codepoint in IP-ECN
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != 2 {
		t.Fatalf("want 2 segments, got %d", len(segs))
	}
	for idx, s := range segs {
		// The low two TOS bits (ECN field) must survive on every segment.
		if s[1]&0x03 != 0x03 {
			t.Errorf("seg %d: CE missing (tos=%#x)", idx, s[1])
		}
		if !verifyChecksum(s[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", idx)
		}
	}
}
// TestSegmentTCPCwrFirstSegmentOnly confirms RFC 3168 §6.1.2: when a TSO
// burst's seed has CWR set, only the first emitted segment carries CWR.
// ECE is preserved on every segment (different signal, persistent state).
func TestSegmentTCPCwrFirstSegmentOnly(t *testing.T) {
	const (
		mss    = 100
		numSeg = 3
	)
	pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
	// Seed flags: CWR | ECE | ACK | PSH.
	pkt[33] = 0x80 | 0x40 | 0x10 | 0x08
	scratch := make([]byte, testSegScratchSize)
	var segs [][]byte
	if err := segmentForTest(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(segs) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(segs))
	}
	for idx, s := range segs {
		flags := s[33]
		hasCwr := flags&0x80 != 0
		hasEce := flags&0x40 != 0
		hasPsh := flags&0x08 != 0
		if wantCwr := idx == 0; hasCwr != wantCwr {
			t.Errorf("seg %d: CWR=%v want %v (flags=%#x)", idx, hasCwr, wantCwr, flags)
		}
		if !hasEce {
			t.Errorf("seg %d: ECE missing (flags=%#x)", idx, flags)
		}
		if wantPsh := idx == numSeg-1; hasPsh != wantPsh {
			t.Errorf("seg %d: PSH=%v want %v (flags=%#x)", idx, hasPsh, wantPsh, flags)
		}
		// IP and TCP checksums must still verify after the flag rewrite.
		if !verifyChecksum(s[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", idx)
		}
		psum := pseudoHeaderIPv4(s[12:16], s[16:20], unix.IPPROTO_TCP, 20+mss)
		if !verifyChecksum(s[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", idx)
		}
	}
}
// BenchmarkSegmentTCPv4 measures segmentation throughput for several
// superpacket payload sizes at a fixed 1460-byte MSS. b.SetBytes accounts
// the superpacket input bytes per iteration, so results read as input
// bytes/sec through the segmenter.
func BenchmarkSegmentTCPv4(b *testing.B) {
	sizes := []struct {
		name   string
		payLen int
		mss    int
	}{
		{"64KiB_MSS1460", 65000, 1460},
		{"16KiB_MSS1460", 16384, 1460},
		{"4KiB_MSS1460", 4096, 1460},
	}
	for _, sz := range sizes {
		b.Run(sz.name, func(b *testing.B) {
			const ipLen = 20
			const tcpLen = 20
			// Build the seed superpacket: IPv4 + TCP headers plus a
			// patterned payload (same layout buildTSOv4 produces).
			pkt := make([]byte, ipLen+tcpLen+sz.payLen)
			pkt[0] = 0x45 // version 4, IHL 5
			binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
			binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // IPv4 ID seed
			pkt[8] = 64 // TTL
			pkt[9] = unix.IPPROTO_TCP
			copy(pkt[12:16], []byte{10, 0, 0, 1})
			copy(pkt[16:20], []byte{10, 0, 0, 2})
			binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
			binary.BigEndian.PutUint16(pkt[22:24], 80)    // dport
			binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
			binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
			pkt[32] = 0x50 // data offset 5 words
			pkt[33] = 0x18 // ACK | PSH
			binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
			for i := 0; i < sz.payLen; i++ {
				pkt[ipLen+tcpLen+i] = byte(i)
			}
			hdr := virtio.Hdr{
				Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
				HdrLen:     uint16(ipLen + tcpLen),
				GSOSize:    uint16(sz.mss),
				CsumStart:  uint16(ipLen),
				CsumOffset: 16,
			}
			scratch := make([]byte, testSegScratchSize)
			// Pre-sized output slice, reset (not reallocated) per iteration
			// so slice growth does not pollute the measurement.
			out := make([][]byte, 0, 64)
			// SegmentSuperpacket consumes its input destructively; restore
			// pkt from a master copy each iteration. The restore mirrors the
			// kernel→userspace copy that hands a fresh GSO blob to the
			// segmenter in production, so it's representative cost rather
			// than bench overhead.
			master := append([]byte(nil), pkt...)
			work := make([]byte, len(pkt))
			b.SetBytes(int64(len(pkt)))
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				copy(work, master)
				out = out[:0]
				if err := segmentForTest(work, hdr, &out, scratch); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
// allocation-free. We write to /dev/null so every call succeeds synchronously.
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
	// NOTE(review): os.O_WRONLY is passed to unix.Open; it aliases the same
	// syscall flag value on linux, but unix.O_WRONLY would be more
	// consistent — confirm before changing (the os import is used only here).
	fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
	if err != nil {
		t.Fatalf("open /dev/null: %v", err)
	}
	t.Cleanup(func() { _ = unix.Close(fd) })
	// Minimal Offload wired directly to the fd; the write path under test
	// needs no rx-side state.
	tf := &Offload{fd: fd}
	payload := make([]byte, 1400)
	// Warm up (first call may trigger one-time internal allocations elsewhere).
	if _, err := tf.Write(payload); err != nil {
		t.Fatalf("Write: %v", err)
	}
	// AllocsPerRun averages over 1000 calls; any steady-state allocation
	// in the write path shows up as a nonzero result.
	allocs := testing.AllocsPerRun(1000, func() {
		if _, err := tf.Write(payload); err != nil {
			t.Fatalf("Write: %v", err)
		}
	})
	if allocs != 0 {
		t.Fatalf("Write allocated %.1f times per call, want 0", allocs)
	}
}
// buildTSOv6 builds a synthetic IPv6/TCP TSO superpacket with payLen bytes
// of payload, segmented at gso. Returns the packet bytes only; the
// virtio_net_hdr is the caller's responsibility. (The gso argument is not
// referenced here — the split size lives entirely in that header.)
func buildTSOv6(payLen, gso int) []byte {
	const (
		ipLen  = 40
		tcpLen = 20
	)
	b := make([]byte, ipLen+tcpLen+payLen)
	// IPv6 header: fe80::1 → fe80::2, next header TCP.
	b[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(b[4:6], uint16(tcpLen+payLen))
	b[6] = unix.IPPROTO_TCP
	b[7] = 64
	b[8], b[9], b[23] = 0xfe, 0x80, 1
	b[24], b[25], b[39] = 0xfe, 0x80, 2
	// TCP header.
	binary.BigEndian.PutUint16(b[40:42], 12345)
	binary.BigEndian.PutUint16(b[42:44], 80)
	binary.BigEndian.PutUint32(b[44:48], 7)
	binary.BigEndian.PutUint32(b[48:52], 99)
	b[52] = 0x50
	b[53] = 0x10 // ACK only
	binary.BigEndian.PutUint16(b[54:56], 65535)
	// Deterministic payload pattern.
	for i := range b[ipLen+tcpLen:] {
		b[ipLen+tcpLen+i] = byte(i)
	}
	return b
}
// TestDecodeReadFitsMaxTSOAtDrainThreshold proves the rxBuf sizing is
// correct: when rxOff is at the maximum value the drain headroom check
// allows, decodeRead must still be able to absorb a worst-case 64KiB
// TSO superpacket without dropping the burst. With segmentation deferred
// to encrypt time, decodeRead writes only the kernel-supplied bytes into
// rxBuf, so the size requirement is just "fit one worst-case input."
//
// Regression history: in a prior layout the rx buffer doubled as the
// segmentation output, a near-threshold drain read returned "scratch too
// small", the whole 45-segment TSO burst was dropped, and the remote's TCP
// fast-retransmit collapsed cwnd. Keeping this test in the new layout
// guards against re-introducing a drain headroom shortfall.
func TestDecodeReadFitsMaxTSOAtDrainThreshold(t *testing.T) {
	const ipv6HdrLen = 40
	const tcpHdrLen = 20
	const headerLen = ipv6HdrLen + tcpHdrLen
	// Maximum TUN read body. The tunReadBufSize cap on readv's body iovec
	// is what bounds the kernel's superpacket length.
	pktLen := tunReadBufSize
	payLen := pktLen - headerLen
	// Pick a gsoSize that splits the full payload into ~64 segments.
	const targetSegs = 64
	gsoSize := (payLen + targetSegs - 1) / targetSegs
	pkt := buildTSOv6(payLen, gsoSize)
	if len(pkt) != pktLen {
		t.Fatalf("buildTSOv6 produced %d bytes, want %d", len(pkt), pktLen)
	}
	o := &Offload{
		rxBuf: make([]byte, tunRxBufCap),
	}
	// rxOff at the maximum value the drain headroom check permits before
	// it would refuse another read. Any drain-time read up to this
	// threshold MUST still process correctly.
	o.rxOff = tunRxBufCap - tunRxBufSize
	// Stage the body in rxBuf as if readv(2) just placed it there.
	copy(o.rxBuf[o.rxOff:], pkt)
	// Encode the matching virtio_net_hdr.
	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
		HdrLen:     uint16(headerLen),
		GSOSize:    uint16(gsoSize),
		CsumStart:  uint16(ipv6HdrLen),
		CsumOffset: 16,
	}
	hdr.Encode(o.readVnetScratch[:])
	startRxOff := o.rxOff
	if err := o.decodeRead(pktLen); err != nil {
		t.Fatalf("decodeRead at drain threshold returned %v — rxBuf sizing regression: "+
			"tunRxBufSize=%d must hold one worst-case input (%d)",
			err, tunRxBufSize, pktLen)
	}
	// The whole superpacket should be queued as one pending entry; it is
	// not split at decode time.
	if len(o.pending) != 1 {
		t.Fatalf("got %d packets, want 1 superpacket entry", len(o.pending))
	}
	got := o.pending[0]
	// The GSO metadata on the pending packet must round-trip the encoded
	// virtio header exactly.
	if !got.GSO.IsSuperpacket() {
		t.Fatalf("expected superpacket GSO metadata, got %+v", got.GSO)
	}
	if got.GSO.Proto != GSOProtoTCP {
		t.Errorf("GSO.Proto=%d want TCP", got.GSO.Proto)
	}
	if got.GSO.Size != uint16(gsoSize) {
		t.Errorf("GSO.Size=%d want %d", got.GSO.Size, gsoSize)
	}
	if got.GSO.HdrLen != uint16(headerLen) {
		t.Errorf("GSO.HdrLen=%d want %d", got.GSO.HdrLen, headerLen)
	}
	if got.GSO.CsumStart != uint16(ipv6HdrLen) {
		t.Errorf("GSO.CsumStart=%d want %d", got.GSO.CsumStart, ipv6HdrLen)
	}
	if len(got.Bytes) != pktLen {
		t.Errorf("len(Bytes)=%d want %d", len(got.Bytes), pktLen)
	}
	// rxOff advances exactly by the kernel-supplied body length — no
	// segmentation output to account for any more.
	if o.rxOff != startRxOff+pktLen {
		t.Errorf("rxOff=%d want %d", o.rxOff, startRxOff+pktLen)
	}
	if o.rxOff > tunRxBufCap {
		t.Fatalf("rxOff=%d overran rxBuf (cap=%d)", o.rxOff, tunRxBufCap)
	}
	// Validate that segmenting the returned superpacket reproduces the
	// expected per-segment IPv6 payload length and TCP checksum.
	wantSegs := (payLen + gsoSize - 1) / gsoSize
	gotSegs := 0
	if err := SegmentSuperpacket(got, func(seg []byte) error {
		// Bump the segment counter even when a check below bails early.
		defer func() { gotSegs++ }()
		if len(seg) < headerLen+1 {
			t.Errorf("seg %d too short: %d", gotSegs, len(seg))
			return nil
		}
		if seg[0]>>4 != 6 {
			t.Errorf("seg %d: bad IP version %#x", gotSegs, seg[0])
		}
		segPay := len(seg) - headerLen
		gotPL := binary.BigEndian.Uint16(seg[4:6])
		if gotPL != uint16(tcpHdrLen+segPay) {
			t.Errorf("seg %d: payload_len=%d want %d", gotSegs, gotPL, tcpHdrLen+segPay)
		}
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpHdrLen+segPay)
		if !verifyChecksum(seg[ipv6HdrLen:], psum) {
			t.Errorf("seg %d: bad TCP checksum", gotSegs)
		}
		return nil
	}); err != nil {
		t.Fatalf("SegmentSuperpacket: %v", err)
	}
	if gotSegs != wantSegs {
		t.Fatalf("got %d segments, want %d", gotSegs, wantSegs)
	}
}