Files
nebula/overlay/checksum/checksum_test.go
2026-05-11 11:32:57 -05:00

191 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package checksum
import (
"fmt"
"math/rand/v2"
"testing"
gvisorchecksum "gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// TestChecksumMatchesGvisor cross-checks the local Checksum against gvisor's
// reference implementation for every length in [0, 4096], combined with a set
// of initial seed values and a set of start offsets into a shared random pool.
// The first differing result fails the test.
func TestChecksumMatchesGvisor(t *testing.T) {
	const (
		maxLen   = 4096
		padFront = 16
	)

	// Fixed PCG seed: the pool contents are the same on every run.
	src := rand.New(rand.NewPCG(1, 2))

	// The pool covers the longest case plus the largest start offset.
	pool := make([]byte, maxLen+padFront)
	for i := 0; i < len(pool); i++ {
		pool[i] = byte(src.Uint32())
	}

	initials := []uint16{0, 0x0001, 0xabcd, 0xffff, 0x1234, 0xfedc}
	starts := []int{0, 1, 2, 3, 4, 5, 7, 8, 15, 16}

	for n := 0; n <= maxLen; n++ {
		for _, initial := range initials {
			for _, start := range starts {
				end := start + n
				if end > len(pool) {
					// Window would run past the pool; nothing to check.
					continue
				}
				window := pool[start:end]
				expected := gvisorchecksum.Checksum(window, initial)
				if actual := Checksum(window, initial); actual != expected {
					t.Fatalf("len=%d off=%d seed=%#x: got %#04x want %#04x",
						n, start, initial, actual, expected)
				}
			}
		}
	}
}
// TestChecksumPatternedBuffers compares Checksum with gvisor's reference on
// structured inputs (all-zero, all-0xff, alternating 0xa5/0x5a, and ascending
// byte values) for every length in [0, 256] and three initial seeds.
//
// The patterns live in a slice rather than a map so they are visited in a
// fixed order. Go randomizes map iteration, so with a map the first mismatch
// reported by t.Fatalf could differ from one run to the next.
func TestChecksumPatternedBuffers(t *testing.T) {
	seeds := []uint16{0, 0xffff, 0x8000}

	for length := 0; length <= 256; length++ {
		patterns := []struct {
			name string
			buf  []byte
		}{
			{"zeros", make([]byte, length)},
			{"ones", bytes(length, 0xff)},
			{"alternating", pattern(length, []byte{0xa5, 0x5a})},
			{"ascending", ascending(length)},
		}

		for _, p := range patterns {
			for _, seed := range seeds {
				want := gvisorchecksum.Checksum(p.buf, seed)
				got := Checksum(p.buf, seed)
				if got != want {
					t.Fatalf("%s len=%d seed=%#x: got %#04x want %#04x",
						p.name, length, seed, got, want)
				}
			}
		}
	}
}
// bytes returns a freshly allocated slice of length n in which every element
// equals v.
func bytes(n int, v byte) []byte {
	filled := make([]byte, n)
	if n == 0 {
		return filled
	}
	// Seed the first element, then repeatedly copy the already-filled prefix
	// onto the region right after it until the whole slice is covered.
	filled[0] = v
	for done := 1; done < n; done *= 2 {
		copy(filled[done:], filled[:done])
	}
	return filled
}
// pattern returns a new slice of length n filled by repeating p from its
// first byte; the final repetition is cut short when n is not a multiple of
// len(p).
func pattern(n int, p []byte) []byte {
	out := make([]byte, n)
	period := len(p)
	for idx := 0; idx < n; idx++ {
		out[idx] = p[idx%period]
	}
	return out
}
// ascending returns a new slice of length n holding 0, 1, 2, ... in order;
// values wrap back to 0 after 255 because each element is a single byte.
func ascending(n int) []byte {
	seq := make([]byte, n)
	var next byte
	for idx := range seq {
		seq[idx] = next
		next++ // byte arithmetic wraps 255 -> 0
	}
	return seq
}
// TestChecksumTailPaths checks Checksum against gvisor's reference for every
// length of the form 64*k + tail, with k in [0, 8] and tail in [0, 63]. The
// intent is to reach each combination of SIMD main-loop iterations and
// trailing bytes that the asm handlers walk through (the tail handlers peel
// off 8, then 4, then 2, then 1 byte chunks). k=0 means no main loop at all,
// only tail; k=1 means one main-loop iteration followed by a tail. This gives
// explicit coverage for payload sizes that are odd or not divisible by 4, 8,
// or 32.
func TestChecksumTailPaths(t *testing.T) {
	const (
		padFront = 16
		maxK     = 8
		chunk    = 64
	)

	// Fixed PCG seed: the pool contents are the same on every run.
	src := rand.New(rand.NewPCG(42, 17))
	pool := make([]byte, chunk*maxK+padFront+chunk)
	for i := 0; i < len(pool); i++ {
		pool[i] = byte(src.Uint32())
	}

	initials := []uint16{0, 0xffff, 0xabcd}
	starts := []int{0, 1, 3, 7, 15} // mix of aligned and odd starts

	// Walking length upward visits (k, tail) in the same order as a nested
	// k-then-tail loop would.
	for length := 0; length < chunk*(maxK+1); length++ {
		k, tail := length/chunk, length%chunk
		for _, initial := range initials {
			for _, start := range starts {
				end := start + length
				if end > len(pool) {
					continue
				}
				window := pool[start:end]
				expected := gvisorchecksum.Checksum(window, initial)
				if actual := Checksum(window, initial); actual != expected {
					t.Fatalf("k=%d tail=%d (len=%d) off=%d seed=%#x: got %#04x want %#04x",
						k, tail, length, start, initial, actual, expected)
				}
			}
		}
	}
}
// BenchmarkChecksumTailSizes benchmarks the local Checksum and gvisor's on
// payload sizes that are not clean multiples of the SIMD body's 32-byte
// (amd64) or 16-byte (arm64) chunks, so that the tail handler is meaningfully
// on the hot path. The sizes either exercise every tail branch (tiny lengths)
// or sit slightly off realistic packet boundaries (e.g. 1499 = MTU - 1).
// Each size yields two sub-benchmarks, "size=N/local" and "size=N/gvisor".
func BenchmarkChecksumTailSizes(b *testing.B) {
	payloadSizes := []int{
		1, 3, 7, 15, 31, // sub-SIMD; entire work is scalar tail
		33, 35, 47, 63, // one loop32 + assorted tails
		65, 95, 127, // one loop64 + assorted tails
		1447, 1471, 1499, 1501, // around MTU
		8191, 8193, // around USO
		65531, 65533, // near the kernel max
	}

	for _, n := range payloadSizes {
		// Deterministic contents: byte i holds the low 8 bits of i.
		payload := make([]byte, n)
		for idx := 0; idx < n; idx++ {
			payload[idx] = byte(idx)
		}

		localName := fmt.Sprintf("size=%d/local", n)
		b.Run(localName, func(b *testing.B) {
			b.SetBytes(int64(n))
			for iter := 0; iter < b.N; iter++ {
				_ = Checksum(payload, 0)
			}
		})

		gvisorName := fmt.Sprintf("size=%d/gvisor", n)
		b.Run(gvisorName, func(b *testing.B) {
			b.SetBytes(int64(n))
			for iter := 0; iter < b.N; iter++ {
				_ = gvisorchecksum.Checksum(payload, 0)
			}
		})
	}
}
// BenchmarkChecksum measures the local Checksum next to gvisor's at sizes
// chosen to match real traffic: a TCP/IP header (60), a typical MSS (1448), a
// typical USO size (8192), and the kernel's max GSO superpacket (65535).
// Each size yields two sub-benchmarks, "size=N/local" and "size=N/gvisor".
func BenchmarkChecksum(b *testing.B) {
	trafficSizes := []int{60, 1448, 8192, 65535}

	for _, n := range trafficSizes {
		// Deterministic contents: byte i holds the low 8 bits of i.
		payload := make([]byte, n)
		for idx := 0; idx < n; idx++ {
			payload[idx] = byte(idx)
		}

		localName := fmt.Sprintf("size=%d/local", n)
		b.Run(localName, func(b *testing.B) {
			b.SetBytes(int64(n))
			for iter := 0; iter < b.N; iter++ {
				_ = Checksum(payload, 0)
			}
		})

		gvisorName := fmt.Sprintf("size=%d/gvisor", n)
		b.Run(gvisorName, func(b *testing.B) {
			b.SetBytes(int64(n))
			for iter := 0; iter < b.N; iter++ {
				_ = gvisorchecksum.Checksum(payload, 0)
			}
		})
	}
}