mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
334 lines
9.4 KiB
Go
334 lines
9.4 KiB
Go
//go:build linux && !android && !e2e_testing
|
|
// +build linux,!android,!e2e_testing
|
|
|
|
package tio
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"os"
|
|
"testing"
|
|
|
|
"golang.org/x/sys/unix"
|
|
)
|
|
|
|
// verifyChecksum confirms that the one's-complement sum across `b`, optionally
|
|
// seeded with a pseudo-header sum, folds to all-ones (valid).
|
|
func verifyChecksum(b []byte, pseudo uint32) bool {
|
|
sum := checksumBytes(b, pseudo)
|
|
for sum>>16 != 0 {
|
|
sum = (sum & 0xffff) + (sum >> 16)
|
|
}
|
|
return uint16(sum) == 0xffff
|
|
}
|
|
|
|
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
|
|
// `payLen` bytes split at `mss`.
|
|
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, VirtioNetHdr) {
|
|
t.Helper()
|
|
const ipLen = 20
|
|
const tcpLen = 20
|
|
pkt := make([]byte, ipLen+tcpLen+payLen)
|
|
|
|
// IPv4 header
|
|
pkt[0] = 0x45 // version 4, IHL 5
|
|
// total length is meaningless for TSO but set it anyway
|
|
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
|
|
binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
|
|
pkt[8] = 64 // TTL
|
|
pkt[9] = unix.IPPROTO_TCP
|
|
copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
|
|
copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
|
|
|
|
// TCP header
|
|
binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
|
|
binary.BigEndian.PutUint16(pkt[22:24], 80) // dport
|
|
binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
|
|
binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
|
|
pkt[32] = 0x50 // data offset 5 words
|
|
pkt[33] = 0x18 // ACK | PSH
|
|
binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
|
|
|
|
// payload
|
|
for i := 0; i < payLen; i++ {
|
|
pkt[ipLen+tcpLen+i] = byte(i & 0xff)
|
|
}
|
|
|
|
return pkt, VirtioNetHdr{
|
|
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
|
HdrLen: uint16(ipLen + tcpLen),
|
|
GSOSize: uint16(mss),
|
|
CsumStart: uint16(ipLen),
|
|
CsumOffset: 16,
|
|
}
|
|
}
|
|
|
|
func TestSegmentTCPv4(t *testing.T) {
|
|
const mss = 100
|
|
const numSeg = 3
|
|
pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
|
|
|
|
scratch := make([]byte, tunSegBufSize)
|
|
var out [][]byte
|
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
|
t.Fatalf("segmentTCP: %v", err)
|
|
}
|
|
if len(out) != numSeg {
|
|
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
|
|
}
|
|
|
|
for i, seg := range out {
|
|
if len(seg) != 40+mss {
|
|
t.Errorf("seg %d: unexpected len %d", i, len(seg))
|
|
}
|
|
totalLen := binary.BigEndian.Uint16(seg[2:4])
|
|
if totalLen != uint16(40+mss) {
|
|
t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
|
|
}
|
|
id := binary.BigEndian.Uint16(seg[4:6])
|
|
if id != 0x4242+uint16(i) {
|
|
t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
|
|
}
|
|
seq := binary.BigEndian.Uint32(seg[24:28])
|
|
wantSeq := uint32(10000 + i*mss)
|
|
if seq != wantSeq {
|
|
t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
|
|
}
|
|
flags := seg[33]
|
|
wantFlags := byte(0x10) // ACK only, PSH cleared
|
|
if i == numSeg-1 {
|
|
wantFlags = 0x18 // ACK | PSH preserved on last
|
|
}
|
|
if flags != wantFlags {
|
|
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
|
|
}
|
|
// IPv4 header checksum must verify against itself.
|
|
if !verifyChecksum(seg[:20], 0) {
|
|
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
|
}
|
|
// TCP checksum must verify against the pseudo-header.
|
|
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
|
|
if !verifyChecksum(seg[20:], psum) {
|
|
t.Errorf("seg %d: bad TCP checksum", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSegmentTCPv4OddTail(t *testing.T) {
|
|
// Payload of 250 bytes with MSS 100 → segments of 100, 100, 50.
|
|
pkt, hdr := buildTSOv4(t, 250, 100)
|
|
scratch := make([]byte, tunSegBufSize)
|
|
var out [][]byte
|
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
|
t.Fatalf("segmentTCP: %v", err)
|
|
}
|
|
if len(out) != 3 {
|
|
t.Fatalf("want 3 segments, got %d", len(out))
|
|
}
|
|
wantPayLens := []int{100, 100, 50}
|
|
for i, seg := range out {
|
|
if len(seg)-40 != wantPayLens[i] {
|
|
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i])
|
|
}
|
|
if !verifyChecksum(seg[:20], 0) {
|
|
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
|
}
|
|
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i])
|
|
if !verifyChecksum(seg[20:], psum) {
|
|
t.Errorf("seg %d: bad TCP checksum", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSegmentTCPv6(t *testing.T) {
|
|
const ipLen = 40
|
|
const tcpLen = 20
|
|
const mss = 120
|
|
const numSeg = 2
|
|
payLen := mss * numSeg
|
|
pkt := make([]byte, ipLen+tcpLen+payLen)
|
|
|
|
// IPv6 header
|
|
pkt[0] = 0x60 // version 6
|
|
binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
|
|
pkt[6] = unix.IPPROTO_TCP
|
|
pkt[7] = 64
|
|
// src/dst fe80::1 / fe80::2
|
|
pkt[8] = 0xfe
|
|
pkt[9] = 0x80
|
|
pkt[23] = 1
|
|
pkt[24] = 0xfe
|
|
pkt[25] = 0x80
|
|
pkt[39] = 2
|
|
|
|
// TCP header
|
|
binary.BigEndian.PutUint16(pkt[40:42], 12345)
|
|
binary.BigEndian.PutUint16(pkt[42:44], 80)
|
|
binary.BigEndian.PutUint32(pkt[44:48], 7)
|
|
binary.BigEndian.PutUint32(pkt[48:52], 99)
|
|
pkt[52] = 0x50
|
|
pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
|
|
binary.BigEndian.PutUint16(pkt[54:56], 65535)
|
|
|
|
for i := 0; i < payLen; i++ {
|
|
pkt[ipLen+tcpLen+i] = byte(i)
|
|
}
|
|
|
|
hdr := VirtioNetHdr{
|
|
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
|
|
HdrLen: uint16(ipLen + tcpLen),
|
|
GSOSize: uint16(mss),
|
|
CsumStart: uint16(ipLen),
|
|
CsumOffset: 16,
|
|
}
|
|
|
|
scratch := make([]byte, tunSegBufSize)
|
|
var out [][]byte
|
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
|
t.Fatalf("segmentTCP: %v", err)
|
|
}
|
|
if len(out) != numSeg {
|
|
t.Fatalf("want %d segments, got %d", numSeg, len(out))
|
|
}
|
|
|
|
for i, seg := range out {
|
|
if len(seg) != ipLen+tcpLen+mss {
|
|
t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
|
|
}
|
|
pl := binary.BigEndian.Uint16(seg[4:6])
|
|
if pl != uint16(tcpLen+mss) {
|
|
t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
|
|
}
|
|
seq := binary.BigEndian.Uint32(seg[44:48])
|
|
if seq != uint32(7+i*mss) {
|
|
t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
|
|
}
|
|
flags := seg[53]
|
|
// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
|
|
// cleared on all but the last; ACK(0x10) always preserved.
|
|
wantFlags := byte(0x10)
|
|
if i == numSeg-1 {
|
|
wantFlags = 0x19
|
|
}
|
|
if flags != wantFlags {
|
|
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
|
|
}
|
|
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
|
|
if !verifyChecksum(seg[ipLen:], psum) {
|
|
t.Errorf("seg %d: bad TCP checksum", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSegmentGSONonePassesThrough(t *testing.T) {
|
|
pkt, hdr := buildTSOv4(t, 100, 100)
|
|
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
|
|
hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
|
|
|
|
scratch := make([]byte, tunSegBufSize)
|
|
var out [][]byte
|
|
if err := segmentInto(pkt, hdr, &out, scratch); err != nil {
|
|
t.Fatalf("segmentInto: %v", err)
|
|
}
|
|
if len(out) != 1 {
|
|
t.Fatalf("want 1 segment, got %d", len(out))
|
|
}
|
|
if len(out[0]) != len(pkt) {
|
|
t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt))
|
|
}
|
|
}
|
|
|
|
func TestSegmentRejectsUDP(t *testing.T) {
|
|
hdr := VirtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
|
|
var out [][]byte
|
|
if err := segmentInto(nil, hdr, &out, nil); err == nil {
|
|
t.Fatalf("expected rejection for UDP GSO")
|
|
}
|
|
}
|
|
|
|
func BenchmarkSegmentTCPv4(b *testing.B) {
|
|
sizes := []struct {
|
|
name string
|
|
payLen int
|
|
mss int
|
|
}{
|
|
{"64KiB_MSS1460", 65000, 1460},
|
|
{"16KiB_MSS1460", 16384, 1460},
|
|
{"4KiB_MSS1460", 4096, 1460},
|
|
}
|
|
for _, sz := range sizes {
|
|
b.Run(sz.name, func(b *testing.B) {
|
|
const ipLen = 20
|
|
const tcpLen = 20
|
|
pkt := make([]byte, ipLen+tcpLen+sz.payLen)
|
|
pkt[0] = 0x45
|
|
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
|
|
binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
|
|
pkt[8] = 64
|
|
pkt[9] = unix.IPPROTO_TCP
|
|
copy(pkt[12:16], []byte{10, 0, 0, 1})
|
|
copy(pkt[16:20], []byte{10, 0, 0, 2})
|
|
binary.BigEndian.PutUint16(pkt[20:22], 12345)
|
|
binary.BigEndian.PutUint16(pkt[22:24], 80)
|
|
binary.BigEndian.PutUint32(pkt[24:28], 10000)
|
|
binary.BigEndian.PutUint32(pkt[28:32], 20000)
|
|
pkt[32] = 0x50
|
|
pkt[33] = 0x18
|
|
binary.BigEndian.PutUint16(pkt[34:36], 65535)
|
|
for i := 0; i < sz.payLen; i++ {
|
|
pkt[ipLen+tcpLen+i] = byte(i)
|
|
}
|
|
hdr := VirtioNetHdr{
|
|
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
|
HdrLen: uint16(ipLen + tcpLen),
|
|
GSOSize: uint16(sz.mss),
|
|
CsumStart: uint16(ipLen),
|
|
CsumOffset: 16,
|
|
}
|
|
|
|
scratch := make([]byte, tunSegBufSize)
|
|
out := make([][]byte, 0, 64)
|
|
|
|
b.SetBytes(int64(len(pkt)))
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
out = out[:0]
|
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
|
|
// allocation-free. We write to /dev/null so every call succeeds synchronously.
|
|
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
|
|
fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
|
|
if err != nil {
|
|
t.Fatalf("open /dev/null: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = unix.Close(fd) })
|
|
|
|
tf := &tunFile{fd: fd}
|
|
tf.writeIovs[0].Base = &validVnetHdr[0]
|
|
tf.writeIovs[0].SetLen(virtioNetHdrLen)
|
|
|
|
payload := make([]byte, 1400)
|
|
// Warm up (first call may trigger one-time internal allocations elsewhere).
|
|
if _, err := tf.Write(payload); err != nil {
|
|
t.Fatalf("Write: %v", err)
|
|
}
|
|
|
|
allocs := testing.AllocsPerRun(1000, func() {
|
|
if _, err := tf.Write(payload); err != nil {
|
|
t.Fatalf("Write: %v", err)
|
|
}
|
|
})
|
|
if allocs != 0 {
|
|
t.Fatalf("Write allocated %.1f times per call, want 0", allocs)
|
|
}
|
|
}
|