switch to ASM vector checksum

This commit is contained in:
JackDoan
2026-05-04 11:56:58 -05:00
parent 6a46a2913a
commit 924268cc1f
9 changed files with 560 additions and 24 deletions

View File

@@ -22,7 +22,7 @@ type TxBatcher interface {
 	// to leave the outer ECN field unset.
 	Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
 	// Flush emits every queued packet via the underlying batch writer in
-	// arrival order. Returns the first error observed. After Flush returns,
+	// arrival order. Returns all observed errors joined via errors.Join. After Flush returns,
 	// borrowed payload slices may be recycled.
 	Flush() error
 }
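
A hypothetical caller of this interface might look like the sketch below (sendBatch, pkts, and dst are illustrative names, not code from this repo); the point of the doc change is that the error returned by Flush now carries every lane failure rather than only the first one.

```go
// sendBatch is a hypothetical caller sketch: queue a batch of packets for one
// destination and surface every failure from the flush, not just the first.
func sendBatch(txb TxBatcher, pkts [][]byte, dst netip.AddrPort) error {
	for _, p := range pkts {
		// outerECN 0 leaves the outer ECN field unset, per the Commit doc.
		txb.Commit(p, dst, 0)
	}
	// Flush drains everything in arrival order; any lane errors come back
	// joined, so a caller that logs the error sees all of them.
	return txb.Flush()
}
```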

View File

@@ -1,6 +1,7 @@
 package batch
 import (
+	"errors"
 	"io"
 )
@@ -60,16 +61,12 @@ func (m *MultiCoalescer) Reserve(sz int) []byte {
 }
 // Commit dispatches pkt to the appropriate lane based on IP version + L4
-// proto. Borrowed slice contract is identical to the single-lane batchers
+// proto. Borrowed slice contract is identical to the single-lane batchers:
 // pkt must remain valid until the next Flush.
 //
 // On the success path the IP/TCP-or-UDP parse happens here once and the
 // parsed struct is handed to the lane via commitParsed so the lane doesn't
-// re-walk the header. On a parse failure we fall through to the lane's
-// public Commit, which re-runs the parse before passthrough — that path
-// only fires for malformed/unsupported packets so the duplicated parse is
-// not on the hot path. The lane's public Commit still works for direct
-// callers.
+// re-walk the header.
 func (m *MultiCoalescer) Commit(pkt []byte) error {
 	if len(pkt) < 20 {
 		return m.pt.Commit(pkt)
@@ -92,9 +89,10 @@ func (m *MultiCoalescer) Commit(pkt []byte) error {
 	if m.tcp != nil {
 		info, ok := parseTCPBase(pkt)
 		if !ok {
-			// Malformed/unsupported TCP shape (IP options, fragments, ...)
-			// — the TCP lane handles this as passthrough.
-			return m.tcp.Commit(pkt)
+			// Malformed/unsupported TCP shape (IP options, fragments, ...).
+			// Hand it to the TCP lane's passthrough path so flow order is preserved.
+			m.tcp.addPassthrough(pkt)
+			return nil
 		}
 		return m.tcp.commitParsed(pkt, info)
 	}
@@ -102,7 +100,8 @@ func (m *MultiCoalescer) Commit(pkt []byte) error {
 	if m.udp != nil {
 		info, ok := parseUDP(pkt)
 		if !ok {
-			return m.udp.Commit(pkt)
+			// The shared passthrough lane (m.pt.Commit) would also work here.
+			m.udp.addPassthrough(pkt)
+			return nil
 		}
 		return m.udp.commitParsed(pkt, info)
 	}
@@ -111,23 +110,24 @@ func (m *MultiCoalescer) Commit(pkt []byte) error {
 }
 // Flush drains every lane in a fixed order: TCP, UDP, passthrough. Errors
-// from a lane do not stop subsequent lanes from flushing we keep
-// draining and return the first observed error so a single bad packet
-// doesn't strand the others.
+// from a lane do not stop subsequent lanes from flushing; we keep
+// draining and return the joined errors so a single bad packet doesn't
+// strand the others.
 func (m *MultiCoalescer) Flush() error {
-	var first error
-	keep := func(err error) {
-		if err != nil && first == nil {
-			first = err
-		}
-	}
+	var errs []error
 	if m.tcp != nil {
-		keep(m.tcp.Flush())
+		if err := m.tcp.Flush(); err != nil {
+			errs = append(errs, err)
+		}
 	}
 	if m.udp != nil {
-		keep(m.udp.Flush())
+		if err := m.udp.Flush(); err != nil {
+			errs = append(errs, err)
+		}
 	}
-	keep(m.pt.Flush())
+	if err := m.pt.Flush(); err != nil {
+		errs = append(errs, err)
+	}
 	m.backing = m.backing[:0]
-	return first
+	return errors.Join(errs...)
 }
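
errors.Join returns nil when no errors were collected, so the success path of Flush is unchanged, and callers can still test for individual lane failures with errors.Is. A standalone illustration (the sentinel errors are hypothetical, not from the batch package):

```go
package main

import (
	"errors"
	"fmt"
)

var errTCPLane = errors.New("tcp lane: short write")    // hypothetical
var errUDPLane = errors.New("udp lane: sendmmsg failed") // hypothetical

func main() {
	// No failures: Join of an empty slice is nil, matching the old
	// "return nil on success" contract.
	var errs []error
	fmt.Println(errors.Join(errs...) == nil) // true

	// Two lanes fail: both errors survive in the joined value and
	// errors.Is still matches each one.
	errs = append(errs, errTCPLane, errUDPLane)
	joined := errors.Join(errs...)
	fmt.Println(errors.Is(joined, errTCPLane)) // true
	fmt.Println(errors.Is(joined, errUDPLane)) // true
}
```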

View File

@@ -0,0 +1,23 @@
package checksum

import (
	"golang.org/x/sys/cpu"
	gvisorchecksum "gvisor.dev/gvisor/pkg/tcpip/checksum"
)

//go:noescape
func checksumAVX2(buf []byte, initial uint16) uint16

var hasAVX2 = cpu.X86.HasAVX2

// Checksum computes the RFC 1071 ones-complement sum of buf, seeded with
// initial. It is a drop-in replacement for gvisor's checksum.Checksum that
// dispatches to a hand-written AVX2 routine on amd64 CPUs that support it,
// falling back to gvisor's pure-Go implementation otherwise. The result
// matches gvisor's bit-for-bit for any buffer length and initial seed.
func Checksum(buf []byte, initial uint16) uint16 {
	if hasAVX2 {
		return checksumAVX2(buf, initial)
	}
	return gvisorchecksum.Checksum(buf, initial)
}
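
For readers unfamiliar with the byte-order trick the assembly relies on, here is a minimal pure-Go model of the same idea: sum little-endian words, fold with end-around carry, and byte-swap once at the end. leChecksum is a hypothetical illustration, not code from this commit.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math/bits"
)

// leChecksum sums buf as little-endian 16-bit words (the natural order on
// amd64/arm64), folds the 64-bit accumulator with end-around carry, and
// byte-swaps once at the end. By RFC 1071 §2(B) byte-order independence this
// equals the usual big-endian internet checksum sum.
func leChecksum(buf []byte, initial uint16) uint16 {
	sum := uint64(bits.ReverseBytes16(initial)) // seed lives in LE space too
	i := 0
	for ; i+1 < len(buf); i += 2 {
		sum += uint64(binary.LittleEndian.Uint16(buf[i:]))
	}
	if i < len(buf) {
		sum += uint64(buf[i]) // odd trailing byte is the low byte in LE order
	}
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16) // fold with end-around carry
	}
	return bits.ReverseBytes16(uint16(sum)) // back to on-wire (BE) order
}

func main() {
	fmt.Printf("%#04x\n", leChecksum([]byte{0x45, 0x00, 0x00, 0x3c}, 0)) // 0x453c
}
```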

View File

@@ -0,0 +1,157 @@
#include "textflag.h"
// func checksumAVX2(buf []byte, initial uint16) uint16
//
// Computes the RFC 1071 ones-complement sum of buf, seeded with initial.
//
// Algorithm: sum the buffer treating it as a stream of uint32s in machine
// (little-endian) byte order, accumulating into 64-bit lanes (the top 32 bits
// hold cross-add carries, which grow by at most 1 per lane per iteration, so
// 32 bits of headroom covers far more than the 16 KB/64 KB max practical inputs).
// At the end we fold to 16 bits and byte-swap once to recover the on-wire
// (big-endian) result. RFC 1071 §2(B) byte-order independence makes this
// equivalent to summing as 16-bit big-endian words.
//
// The ymm accumulators (Y4..Y7) hold 4 uint64 lanes each = 16 parallel
// partial sums. The main loop loads 64 bytes per iter as four 16-byte
// chunks, zero-extending each chunk's four uint32s into a ymm via
// VPMOVZXDQ-from-memory, then VPADDQ into a separate accumulator per
// chunk to break the dep chain. After the vector loop the lane sums are
// horizontally reduced and merged with a scalar accumulator that handles
// the trailing 0..63 bytes plus the (byte-swapped) initial seed.
TEXT ·checksumAVX2(SB), NOSPLIT, $0-34
MOVQ buf_base+0(FP), SI
MOVQ buf_len+8(FP), CX
MOVWQZX initial+24(FP), AX
// Pre-byteswap initial into the LE-summing space so it merges directly
// with the rest of the accumulator. The final fold's bswap16 will undo
// this and convert the whole result back to BE.
XCHGB AH, AL
CMPQ CX, $32
JLT scalar_tail
VPXOR Y4, Y4, Y4
VPXOR Y5, Y5, Y5
VPXOR Y6, Y6, Y6
VPXOR Y7, Y7, Y7
CMPQ CX, $64
JLT loop32
loop64:
VPMOVZXDQ (SI), Y0
VPMOVZXDQ 16(SI), Y1
VPMOVZXDQ 32(SI), Y2
VPMOVZXDQ 48(SI), Y3
VPADDQ Y0, Y4, Y4
VPADDQ Y1, Y5, Y5
VPADDQ Y2, Y6, Y6
VPADDQ Y3, Y7, Y7
ADDQ $64, SI
SUBQ $64, CX
CMPQ CX, $64
JGE loop64
loop32:
CMPQ CX, $32
JLT reduce_vec
VPMOVZXDQ (SI), Y0
VPMOVZXDQ 16(SI), Y1
VPADDQ Y0, Y4, Y4
VPADDQ Y1, Y5, Y5
ADDQ $32, SI
SUBQ $32, CX
JMP loop32
reduce_vec:
// Combine the four ymm accumulators into Y4.
VPADDQ Y5, Y4, Y4
VPADDQ Y7, Y6, Y6
VPADDQ Y6, Y4, Y4
// Horizontally reduce Y4's four uint64 lanes to a single scalar.
VEXTRACTI128 $1, Y4, X5
VPADDQ X5, X4, X4
VPSHUFD $0x4e, X4, X5
VPADDQ X5, X4, X4
VMOVQ X4, R8
VZEROUPPER
ADDQ R8, AX
ADCQ $0, AX
scalar_tail:
// Handle remaining 0..63 bytes (or the entire buffer if it was < 32).
CMPQ CX, $8
JLT tail4
loop8:
ADDQ (SI), AX
ADCQ $0, AX
ADDQ $8, SI
SUBQ $8, CX
CMPQ CX, $8
JGE loop8
tail4:
CMPQ CX, $4
JLT tail2
MOVL (SI), R8
ADDQ R8, AX
ADCQ $0, AX
ADDQ $4, SI
SUBQ $4, CX
tail2:
CMPQ CX, $2
JLT tail1
MOVWQZX (SI), R8
ADDQ R8, AX
ADCQ $0, AX
ADDQ $2, SI
SUBQ $2, CX
tail1:
TESTQ CX, CX
JZ fold
MOVBQZX (SI), R8
ADDQ R8, AX
ADCQ $0, AX
fold:
// Fold the 64-bit accumulator to 16 bits via four rounds, mirroring
// gvisor's reduce(). Each pair (split, add) halves the live width;
// the truncation steps absorb the single bit that may be left over
// after each add so the next round's bound holds.
// 64 → 33 bits.
MOVQ AX, R8
SHRQ $32, R8
MOVL AX, AX
ADDQ R8, AX
// 33 → 32 bits. AX += (AX>>32); truncate to 32. AX is now ≤ 0xFFFF_FFFF.
MOVQ AX, R8
SHRQ $32, R8
ADDQ R8, AX
MOVL AX, AX
// 32 → 17 bits.
MOVQ AX, R8
SHRQ $16, R8
MOVWQZX AX, AX
ADDQ R8, AX
// 17 → 16 bits. AX += (AX>>16); the trailing MOVW truncates bit 16.
MOVQ AX, R8
SHRQ $16, R8
ADDQ R8, AX
// AX low 16 bits hold the 16-bit sum in machine (LE) byte order; flip
// to big-endian to match the gvisor API contract.
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
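
The four-round fold is easier to check in Go; this sketch mirrors the assembly's reduce sequence (illustrative only, and it stops before the final byte swap back to big-endian):

```go
// reduce folds a plain 64-bit binary sum of the payload down to the 16-bit
// ones-complement sum, mirroring the four asm rounds above. The final
// byte-swap back to big-endian (the XCHGB / REV16W step) is left out here.
func reduce(sum uint64) uint16 {
	sum = (sum & 0xffff_ffff) + (sum >> 32) // 64 → 33 bits
	sum = (sum & 0xffff_ffff) + (sum >> 32) // 33 → 32 bits
	sum = (sum & 0xffff) + (sum >> 16)      // 32 → 17 bits
	sum = (sum & 0xffff) + (sum >> 16)      // 17 → 16 bits
	return uint16(sum)
}
```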

View File

@@ -0,0 +1,12 @@
package checksum

//go:noescape
func checksumNEON(buf []byte, initial uint16) uint16

// Checksum computes the RFC 1071 ones-complement sum of buf, seeded with
// initial. It is a drop-in replacement for gvisor's checksum.Checksum
// that dispatches to a hand-written NEON routine. NEON is mandatory in
// armv8 so no feature check is needed.
func Checksum(buf []byte, initial uint16) uint16 {
	return checksumNEON(buf, initial)
}

View File

@@ -0,0 +1,143 @@
#include "textflag.h"
// func checksumNEON(buf []byte, initial uint16) uint16
//
// Mirrors the algorithm in checksum_amd64.s: sum the buffer treating it as
// a stream of uint32s in machine (little-endian) byte order, accumulating
// into 64-bit lanes that have ample carry headroom; fold and byte-swap once
// at the very end to recover the on-wire (big-endian) result.
//
// Each loop iteration loads 64 bytes via VLD1.P into V0..V3 (4 Q regs).
// VUADDW takes the low two uint32 lanes of a Q reg, zero-extends them to
// uint64, and adds them into a 2×uint64 accumulator; VUADDW2 does the same
// for the high two lanes. Four ymm-equivalent accumulators (V8..V11) get
// updated twice per iter to break the dep chain. Tail bytes go through a
// scalar ADCS chain seeded with the byte-swapped initial.
TEXT ·checksumNEON(SB), NOSPLIT, $0-34
MOVD buf_base+0(FP), R0
MOVD buf_len+8(FP), R1
MOVHU initial+24(FP), R2
// Pre-byteswap initial into the LE-summing space so it merges directly
// with the rest of the accumulator.
REV16W R2, R2
MOVD ZR, R3 // scalar accumulator
CMP $32, R1
BLT scalar_tail
VEOR V8.B16, V8.B16, V8.B16
VEOR V9.B16, V9.B16, V9.B16
VEOR V10.B16, V10.B16, V10.B16
VEOR V11.B16, V11.B16, V11.B16
CMP $64, R1
BLT loop16_init
loop64:
VLD1.P 64(R0), [V0.B16, V1.B16, V2.B16, V3.B16]
VUADDW V0.S2, V8.D2, V8.D2
VUADDW2 V0.S4, V9.D2, V9.D2
VUADDW V1.S2, V10.D2, V10.D2
VUADDW2 V1.S4, V11.D2, V11.D2
VUADDW V2.S2, V8.D2, V8.D2
VUADDW2 V2.S4, V9.D2, V9.D2
VUADDW V3.S2, V10.D2, V10.D2
VUADDW2 V3.S4, V11.D2, V11.D2
SUB $64, R1, R1
CMP $64, R1
BGE loop64
loop16_init:
CMP $16, R1
BLT reduce_vec
loop16:
VLD1.P 16(R0), [V0.B16]
VUADDW V0.S2, V8.D2, V8.D2
VUADDW2 V0.S4, V9.D2, V9.D2
SUB $16, R1, R1
CMP $16, R1
BGE loop16
reduce_vec:
// Combine the four accumulators into V8.
VADD V9.D2, V8.D2, V8.D2
VADD V11.D2, V10.D2, V10.D2
VADD V10.D2, V8.D2, V8.D2
// Horizontal-add the two lanes of V8.D2 into a single uint64.
VADDP V8.D2, V8.D2, V8.D2
VMOV V8.D[0], R8
ADDS R8, R3, R3
ADC ZR, R3, R3
scalar_tail:
CMP $8, R1
BLT tail4
loop8:
MOVD.P 8(R0), R8
ADDS R8, R3, R3
ADC ZR, R3, R3
SUB $8, R1, R1
CMP $8, R1
BGE loop8
tail4:
CMP $4, R1
BLT tail2
MOVWU.P 4(R0), R8
ADDS R8, R3, R3
ADC ZR, R3, R3
SUB $4, R1, R1
tail2:
CMP $2, R1
BLT tail1
MOVHU.P 2(R0), R8
ADDS R8, R3, R3
ADC ZR, R3, R3
SUB $2, R1, R1
tail1:
CBZ R1, fold
MOVBU (R0), R8
ADDS R8, R3, R3
ADC ZR, R3, R3
fold:
// Merge the byte-swapped initial into our LE-form accumulator.
ADDS R2, R3, R3
ADC ZR, R3, R3
// 64 → 33 bits.
LSR $32, R3, R8
AND $0xffffffff, R3, R3
ADD R8, R3, R3
// 33 → 32 (truncate after adding bit 32 back).
LSR $32, R3, R8
ADD R8, R3, R3
AND $0xffffffff, R3, R3
// 32 → 17.
LSR $16, R3, R8
AND $0xffff, R3, R3
ADD R8, R3, R3
// 17 → 16 (truncation absorbs bit 16 below).
LSR $16, R3, R8
ADD R8, R3, R3
// R3's low 16 bits hold the 16-bit sum in machine (LE) byte order; flip
// to big-endian to match the gvisor API contract. REV16W swaps bytes
// within each 16-bit halfword of the low 32 bits, so it acts as a
// 16-bit byte-swap on the live low 16.
REV16W R3, R3
AND $0xffff, R3, R3
MOVH R3, ret+32(FP)
RET

View File

@@ -0,0 +1,10 @@
//go:build !amd64 && !arm64

package checksum

import gvisorchecksum "gvisor.dev/gvisor/pkg/tcpip/checksum"

// Checksum delegates to gvisor on architectures without a hand-written body.
func Checksum(buf []byte, initial uint16) uint16 {
	return gvisorchecksum.Checksum(buf, initial)
}

View File

@@ -0,0 +1,190 @@
package checksum
import (
"fmt"
"math/rand/v2"
"testing"
gvisorchecksum "gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// TestChecksumMatchesGvisor walks lengths from 0 to 4096, with several initial
// seeds and a handful of starting alignments, asserting that our local
// Checksum matches gvisor's reference bit-for-bit.
func TestChecksumMatchesGvisor(t *testing.T) {
rng := rand.New(rand.NewPCG(1, 2))
const padFront = 16
// Random pool large enough for the longest case + alignment slop.
pool := make([]byte, 4096+padFront)
for i := range pool {
pool[i] = byte(rng.Uint32())
}
seeds := []uint16{0, 0x0001, 0xabcd, 0xffff, 0x1234, 0xfedc}
offsets := []int{0, 1, 2, 3, 4, 5, 7, 8, 15, 16}
for length := 0; length <= 4096; length++ {
for _, seed := range seeds {
for _, off := range offsets {
if off+length > len(pool) {
continue
}
buf := pool[off : off+length]
want := gvisorchecksum.Checksum(buf, seed)
got := Checksum(buf, seed)
if got != want {
t.Fatalf("len=%d off=%d seed=%#x: got %#04x want %#04x",
length, off, seed, got, want)
}
}
}
}
}
// TestChecksumPatternedBuffers exercises specific byte patterns that have
// historically tripped up checksum implementations: all-zero, all-0xff,
// alternating, and ascending sequences.
func TestChecksumPatternedBuffers(t *testing.T) {
for length := 0; length <= 256; length++ {
patterns := map[string][]byte{
"zeros": make([]byte, length),
"ones": bytes(length, 0xff),
"alternating": pattern(length, []byte{0xa5, 0x5a}),
"ascending": ascending(length),
}
for name, buf := range patterns {
for _, seed := range []uint16{0, 0xffff, 0x8000} {
want := gvisorchecksum.Checksum(buf, seed)
got := Checksum(buf, seed)
if got != want {
t.Fatalf("%s len=%d seed=%#x: got %#04x want %#04x",
name, length, seed, got, want)
}
}
}
}
}
func bytes(n int, v byte) []byte {
b := make([]byte, n)
for i := range b {
b[i] = v
}
return b
}
func pattern(n int, p []byte) []byte {
b := make([]byte, n)
for i := range b {
b[i] = p[i%len(p)]
}
return b
}
func ascending(n int) []byte {
b := make([]byte, n)
for i := range b {
b[i] = byte(i)
}
return b
}
// TestChecksumTailPaths targets every combination of (SIMD body iterations,
// trailing tail bytes) the asm handlers walk through. The tail handlers
// peel off 8 → 4 → 2 → 1 byte chunks in turn; this test exercises each by
// constructing lengths of the form 64*k + tail for tail ∈ [0, 63] and a
// representative spread of k values, including k=0 (no main loop, all tail)
// and k=1 (one main loop iter, then tail). It's explicit coverage for
// payload sizes that are odd, not divisible by 4, by 8, or by 32.
func TestChecksumTailPaths(t *testing.T) {
rng := rand.New(rand.NewPCG(42, 17))
const padFront = 16
const maxK = 8
pool := make([]byte, 64*maxK+padFront+64)
for i := range pool {
pool[i] = byte(rng.Uint32())
}
seeds := []uint16{0, 0xffff, 0xabcd}
offsets := []int{0, 1, 3, 7, 15} // mix of aligned and odd starts
for k := 0; k <= maxK; k++ {
for tail := 0; tail < 64; tail++ {
length := 64*k + tail
for _, seed := range seeds {
for _, off := range offsets {
if off+length > len(pool) {
continue
}
buf := pool[off : off+length]
want := gvisorchecksum.Checksum(buf, seed)
got := Checksum(buf, seed)
if got != want {
t.Fatalf("k=%d tail=%d (len=%d) off=%d seed=%#x: got %#04x want %#04x",
k, tail, length, off, seed, got, want)
}
}
}
}
}
}
// BenchmarkChecksumTailSizes covers payload sizes that aren't clean multiples
// of the SIMD body's 32-byte (amd64) or 16-byte (arm64) chunks, so the tail
// handler is meaningfully on the hot path. Sizes are picked to either exercise
// every tail branch (tiny lengths) or sit slightly off realistic packet
// boundaries (e.g. 1499 = MTU - 1).
func BenchmarkChecksumTailSizes(b *testing.B) {
sizes := []int{
1, 3, 7, 15, 31, // sub-SIMD; entire work is scalar tail
33, 35, 47, 63, // one loop32 + assorted tails
65, 95, 127, // one loop64 + assorted tails
1447, 1471, 1499, 1501, // around MTU
8191, 8193, // around USO
65531, 65533, // near the kernel max
}
for _, size := range sizes {
buf := make([]byte, size)
for i := range buf {
buf[i] = byte(i)
}
b.Run(fmt.Sprintf("size=%d/local", size), func(b *testing.B) {
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
_ = Checksum(buf, 0)
}
})
b.Run(fmt.Sprintf("size=%d/gvisor", size), func(b *testing.B) {
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
_ = gvisorchecksum.Checksum(buf, 0)
}
})
}
}
// BenchmarkChecksum compares the local Checksum to gvisor's at sizes that
// match real traffic: a TCP/IP header (60), a typical MSS (1448), a typical
// USO size (8192), and the kernel's max GSO superpacket (65535).
func BenchmarkChecksum(b *testing.B) {
for _, size := range []int{60, 1448, 8192, 65535} {
buf := make([]byte, size)
for i := range buf {
buf[i] = byte(i)
}
b.Run(fmt.Sprintf("size=%d/local", size), func(b *testing.B) {
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
_ = Checksum(buf, 0)
}
})
b.Run(fmt.Sprintf("size=%d/gvisor", size), func(b *testing.B) {
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
_ = gvisorchecksum.Checksum(buf, 0)
}
})
}
}

View File

@@ -14,7 +14,8 @@ import (
"fmt" "fmt"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"github.com/slackhq/nebula/overlay/checksum"
) )
// Protocol header size bounds used to validate / cap kernel-supplied offsets. // Protocol header size bounds used to validate / cap kernel-supplied offsets.