mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
switch to ASM vector checksum
This commit is contained in:
@@ -22,7 +22,7 @@ type TxBatcher interface {
|
|||||||
// to leave the outer ECN field unset.
|
// to leave the outer ECN field unset.
|
||||||
Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
|
Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
|
||||||
// Flush emits every queued packet via the underlying batch writer in
|
// Flush emits every queued packet via the underlying batch writer in
|
||||||
// arrival order. Returns the first error observed. After Flush returns,
|
// arrival order. Returns an errors.Join of one or more errors. After Flush returns,
|
||||||
// borrowed payload slices may be recycled.
|
// borrowed payload slices may be recycled.
|
||||||
Flush() error
|
Flush() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package batch
|
package batch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,16 +61,12 @@ func (m *MultiCoalescer) Reserve(sz int) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Commit dispatches pkt to the appropriate lane based on IP version + L4
|
// Commit dispatches pkt to the appropriate lane based on IP version + L4
|
||||||
// proto. Borrowed slice contract is identical to the single-lane batchers
|
// proto. Borrowed slice contract is identical to the single-lane batchers:
|
||||||
// — pkt must remain valid until the next Flush.
|
// pkt must remain valid until the next Flush.
|
||||||
//
|
//
|
||||||
// On the success path the IP/TCP-or-UDP parse happens here once and the
|
// On the success path the IP/TCP-or-UDP parse happens here once and the
|
||||||
// parsed struct is handed to the lane via commitParsed so the lane doesn't
|
// parsed struct is handed to the lane via commitParsed so the lane doesn't
|
||||||
// re-walk the header. On a parse failure we fall through to the lane's
|
// re-walk the header.
|
||||||
// public Commit, which re-runs the parse before passthrough — that path
|
|
||||||
// only fires for malformed/unsupported packets so the duplicated parse is
|
|
||||||
// not on the hot path. The lane's public Commit still works for direct
|
|
||||||
// callers.
|
|
||||||
func (m *MultiCoalescer) Commit(pkt []byte) error {
|
func (m *MultiCoalescer) Commit(pkt []byte) error {
|
||||||
if len(pkt) < 20 {
|
if len(pkt) < 20 {
|
||||||
return m.pt.Commit(pkt)
|
return m.pt.Commit(pkt)
|
||||||
@@ -92,9 +89,10 @@ func (m *MultiCoalescer) Commit(pkt []byte) error {
|
|||||||
if m.tcp != nil {
|
if m.tcp != nil {
|
||||||
info, ok := parseTCPBase(pkt)
|
info, ok := parseTCPBase(pkt)
|
||||||
if !ok {
|
if !ok {
|
||||||
// Malformed/unsupported TCP shape (IP options, fragments, ...)
|
// Malformed/unsupported TCP shape (IP options, fragments, ...).
|
||||||
// — the TCP lane handles this as passthrough.
|
// Handle this via passthrough support in the TCP coalescer, to attempt to preserve flow order.
|
||||||
return m.tcp.Commit(pkt)
|
m.tcp.addPassthrough(pkt)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return m.tcp.commitParsed(pkt, info)
|
return m.tcp.commitParsed(pkt, info)
|
||||||
}
|
}
|
||||||
@@ -102,7 +100,8 @@ func (m *MultiCoalescer) Commit(pkt []byte) error {
|
|||||||
if m.udp != nil {
|
if m.udp != nil {
|
||||||
info, ok := parseUDP(pkt)
|
info, ok := parseUDP(pkt)
|
||||||
if !ok {
|
if !ok {
|
||||||
return m.udp.Commit(pkt)
|
m.udp.addPassthrough(pkt) //we could also m.pt.Commit() here I guess?
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return m.udp.commitParsed(pkt, info)
|
return m.udp.commitParsed(pkt, info)
|
||||||
}
|
}
|
||||||
@@ -111,23 +110,24 @@ func (m *MultiCoalescer) Commit(pkt []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Flush drains every lane in a fixed order: TCP, UDP, passthrough. Errors
|
// Flush drains every lane in a fixed order: TCP, UDP, passthrough. Errors
|
||||||
// from a lane do not stop subsequent lanes from flushing — we keep
|
// from a lane do not stop subsequent lanes from flushing; we keep
|
||||||
// draining and return the first observed error so a single bad packet
|
// draining and return every lane's error via errors.Join so a single bad packet
|
||||||
// doesn't strand the others.
|
// doesn't strand the others.
|
||||||
func (m *MultiCoalescer) Flush() error {
|
func (m *MultiCoalescer) Flush() error {
|
||||||
var first error
|
var errs []error
|
||||||
keep := func(err error) {
|
if m.tcp != nil {
|
||||||
if err != nil && first == nil {
|
if err := m.tcp.Flush(); err != nil {
|
||||||
first = err
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if m.tcp != nil {
|
|
||||||
keep(m.tcp.Flush())
|
|
||||||
}
|
|
||||||
if m.udp != nil {
|
if m.udp != nil {
|
||||||
keep(m.udp.Flush())
|
if err := m.udp.Flush(); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := m.pt.Flush(); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
keep(m.pt.Flush())
|
|
||||||
m.backing = m.backing[:0]
|
m.backing = m.backing[:0]
|
||||||
return first
|
return errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|||||||
23
overlay/checksum/checksum_amd64.go
Normal file
23
overlay/checksum/checksum_amd64.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package checksum
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/sys/cpu"
|
||||||
|
gvisorchecksum "gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:noescape
|
||||||
|
func checksumAVX2(buf []byte, initial uint16) uint16
|
||||||
|
|
||||||
|
var hasAVX2 = cpu.X86.HasAVX2
|
||||||
|
|
||||||
|
// Checksum computes the RFC 1071 ones-complement sum of buf, seeded with
|
||||||
|
// initial. It is a drop-in replacement for gvisor's checksum.Checksum that
|
||||||
|
// dispatches to a hand-written AVX2 routine on amd64 CPUs that support it,
|
||||||
|
// falling back to gvisor's pure-Go implementation otherwise. The result
|
||||||
|
// matches gvisor's bit-for-bit for any buffer length and initial seed.
|
||||||
|
func Checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
if hasAVX2 {
|
||||||
|
return checksumAVX2(buf, initial)
|
||||||
|
}
|
||||||
|
return gvisorchecksum.Checksum(buf, initial)
|
||||||
|
}
|
||||||
157
overlay/checksum/checksum_amd64.s
Normal file
157
overlay/checksum/checksum_amd64.s
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
#include "textflag.h"
|
||||||
|
|
||||||
|
// func checksumAVX2(buf []byte, initial uint16) uint16
|
||||||
|
//
|
||||||
|
// Computes the RFC 1071 ones-complement sum of buf, seeded with initial.
|
||||||
|
//
|
||||||
|
// Algorithm: sum the buffer treating it as a stream of uint32s in machine
|
||||||
|
// (little-endian) byte order, accumulating into 64-bit lanes (top 32 bits
|
||||||
|
// hold cross-add carries — at one uint32 / lane / iter we have 32 bits of
|
||||||
|
// headroom which is far more than the 16 KB/64 KB max practical inputs).
|
||||||
|
// At the end we fold to 16 bits and byte-swap once to recover the on-wire
|
||||||
|
// (big-endian) result. RFC 1071 §2(B) byte-order independence makes this
|
||||||
|
// equivalent to summing as 16-bit big-endian words.
|
||||||
|
//
|
||||||
|
// The ymm accumulators (Y4..Y7) hold 4 uint64 lanes each = 16 parallel
|
||||||
|
// partial sums. The main loop loads 64 bytes per iter as four 16-byte
|
||||||
|
// chunks, zero-extending each chunk's four uint32s into a ymm via
|
||||||
|
// VPMOVZXDQ-from-memory, then VPADDQ into a separate accumulator per
|
||||||
|
// chunk to break the dep chain. After the vector loop the lane sums are
|
||||||
|
// horizontally reduced and merged with a scalar accumulator that handles
|
||||||
|
// the trailing 0..31 bytes plus the (byte-swapped) initial seed.
|
||||||
|
TEXT ·checksumAVX2(SB), NOSPLIT, $0-34
|
||||||
|
MOVQ buf_base+0(FP), SI
|
||||||
|
MOVQ buf_len+8(FP), CX
|
||||||
|
MOVWQZX initial+24(FP), AX
|
||||||
|
|
||||||
|
// Pre-byteswap initial into the LE-summing space so it merges directly
|
||||||
|
// with the rest of the accumulator. The final fold's bswap16 will undo
|
||||||
|
// this and convert the whole result back to BE.
|
||||||
|
XCHGB AH, AL
|
||||||
|
|
||||||
|
CMPQ CX, $32
|
||||||
|
JLT scalar_tail
|
||||||
|
|
||||||
|
VPXOR Y4, Y4, Y4
|
||||||
|
VPXOR Y5, Y5, Y5
|
||||||
|
VPXOR Y6, Y6, Y6
|
||||||
|
VPXOR Y7, Y7, Y7
|
||||||
|
|
||||||
|
CMPQ CX, $64
|
||||||
|
JLT loop32
|
||||||
|
|
||||||
|
loop64:
|
||||||
|
VPMOVZXDQ (SI), Y0
|
||||||
|
VPMOVZXDQ 16(SI), Y1
|
||||||
|
VPMOVZXDQ 32(SI), Y2
|
||||||
|
VPMOVZXDQ 48(SI), Y3
|
||||||
|
VPADDQ Y0, Y4, Y4
|
||||||
|
VPADDQ Y1, Y5, Y5
|
||||||
|
VPADDQ Y2, Y6, Y6
|
||||||
|
VPADDQ Y3, Y7, Y7
|
||||||
|
ADDQ $64, SI
|
||||||
|
SUBQ $64, CX
|
||||||
|
CMPQ CX, $64
|
||||||
|
JGE loop64
|
||||||
|
|
||||||
|
loop32:
|
||||||
|
CMPQ CX, $32
|
||||||
|
JLT reduce_vec
|
||||||
|
VPMOVZXDQ (SI), Y0
|
||||||
|
VPMOVZXDQ 16(SI), Y1
|
||||||
|
VPADDQ Y0, Y4, Y4
|
||||||
|
VPADDQ Y1, Y5, Y5
|
||||||
|
ADDQ $32, SI
|
||||||
|
SUBQ $32, CX
|
||||||
|
JMP loop32
|
||||||
|
|
||||||
|
reduce_vec:
|
||||||
|
// Combine the four ymm accumulators into Y4.
|
||||||
|
VPADDQ Y5, Y4, Y4
|
||||||
|
VPADDQ Y7, Y6, Y6
|
||||||
|
VPADDQ Y6, Y4, Y4
|
||||||
|
|
||||||
|
// Horizontally reduce Y4's four uint64 lanes to a single scalar.
|
||||||
|
VEXTRACTI128 $1, Y4, X5
|
||||||
|
VPADDQ X5, X4, X4
|
||||||
|
VPSHUFD $0x4e, X4, X5
|
||||||
|
VPADDQ X5, X4, X4
|
||||||
|
VMOVQ X4, R8
|
||||||
|
VZEROUPPER
|
||||||
|
|
||||||
|
ADDQ R8, AX
|
||||||
|
ADCQ $0, AX
|
||||||
|
|
||||||
|
scalar_tail:
|
||||||
|
// Handle remaining 0..31 bytes (or the entire buffer if it was < 32).
|
||||||
|
CMPQ CX, $8
|
||||||
|
JLT tail4
|
||||||
|
|
||||||
|
loop8:
|
||||||
|
ADDQ (SI), AX
|
||||||
|
ADCQ $0, AX
|
||||||
|
ADDQ $8, SI
|
||||||
|
SUBQ $8, CX
|
||||||
|
CMPQ CX, $8
|
||||||
|
JGE loop8
|
||||||
|
|
||||||
|
tail4:
|
||||||
|
CMPQ CX, $4
|
||||||
|
JLT tail2
|
||||||
|
MOVL (SI), R8
|
||||||
|
ADDQ R8, AX
|
||||||
|
ADCQ $0, AX
|
||||||
|
ADDQ $4, SI
|
||||||
|
SUBQ $4, CX
|
||||||
|
|
||||||
|
tail2:
|
||||||
|
CMPQ CX, $2
|
||||||
|
JLT tail1
|
||||||
|
MOVWQZX (SI), R8
|
||||||
|
ADDQ R8, AX
|
||||||
|
ADCQ $0, AX
|
||||||
|
ADDQ $2, SI
|
||||||
|
SUBQ $2, CX
|
||||||
|
|
||||||
|
tail1:
|
||||||
|
TESTQ CX, CX
|
||||||
|
JZ fold
|
||||||
|
MOVBQZX (SI), R8
|
||||||
|
ADDQ R8, AX
|
||||||
|
ADCQ $0, AX
|
||||||
|
|
||||||
|
fold:
|
||||||
|
// Fold the 64-bit accumulator to 16 bits via four rounds, mirroring
|
||||||
|
// gvisor's reduce(). Each pair (split, add) halves the live width;
|
||||||
|
// the truncation steps absorb the single bit that may be left over
|
||||||
|
// after each add so the next round's bound holds.
|
||||||
|
|
||||||
|
// 64 → 33 bits.
|
||||||
|
MOVQ AX, R8
|
||||||
|
SHRQ $32, R8
|
||||||
|
MOVL AX, AX
|
||||||
|
ADDQ R8, AX
|
||||||
|
|
||||||
|
// 33 → 32 bits. AX += (AX>>32); truncate to 32. AX is now ≤ 0xFFFF_FFFF.
|
||||||
|
MOVQ AX, R8
|
||||||
|
SHRQ $32, R8
|
||||||
|
ADDQ R8, AX
|
||||||
|
MOVL AX, AX
|
||||||
|
|
||||||
|
// 32 → 17 bits.
|
||||||
|
MOVQ AX, R8
|
||||||
|
SHRQ $16, R8
|
||||||
|
MOVWQZX AX, AX
|
||||||
|
ADDQ R8, AX
|
||||||
|
|
||||||
|
// 17 → 16 bits. AX += (AX>>16); the trailing MOVW truncates bit 16.
|
||||||
|
MOVQ AX, R8
|
||||||
|
SHRQ $16, R8
|
||||||
|
ADDQ R8, AX
|
||||||
|
|
||||||
|
// AX low 16 bits hold the 16-bit sum in machine (LE) byte order; flip
|
||||||
|
// to big-endian to match the gvisor API contract.
|
||||||
|
XCHGB AH, AL
|
||||||
|
|
||||||
|
MOVW AX, ret+32(FP)
|
||||||
|
RET
|
||||||
12
overlay/checksum/checksum_arm64.go
Normal file
12
overlay/checksum/checksum_arm64.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package checksum
|
||||||
|
|
||||||
|
//go:noescape
|
||||||
|
func checksumNEON(buf []byte, initial uint16) uint16
|
||||||
|
|
||||||
|
// Checksum computes the RFC 1071 ones-complement sum of buf, seeded with
|
||||||
|
// initial. It is a drop-in replacement for gvisor's checksum.Checksum
|
||||||
|
// that dispatches to a hand-written NEON routine. NEON is mandatory in
|
||||||
|
// armv8 so no feature check is needed.
|
||||||
|
func Checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
return checksumNEON(buf, initial)
|
||||||
|
}
|
||||||
143
overlay/checksum/checksum_arm64.s
Normal file
143
overlay/checksum/checksum_arm64.s
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
#include "textflag.h"
|
||||||
|
|
||||||
|
// func checksumNEON(buf []byte, initial uint16) uint16
|
||||||
|
//
|
||||||
|
// Mirrors the algorithm in checksum_amd64.s: sum the buffer treating it as
|
||||||
|
// a stream of uint32s in machine (little-endian) byte order, accumulating
|
||||||
|
// into 64-bit lanes that have ample carry headroom; fold and byte-swap once
|
||||||
|
// at the very end to recover the on-wire (big-endian) result.
|
||||||
|
//
|
||||||
|
// Each loop iteration loads 64 bytes via VLD1.P into V0..V3 (4 Q regs).
|
||||||
|
// VUADDW takes the low two uint32 lanes of a Q reg, zero-extends them to
|
||||||
|
// uint64, and adds them into a 2×uint64 accumulator; VUADDW2 does the same
|
||||||
|
// for the high two lanes. Four 2×uint64 accumulators (V8..V11) get
|
||||||
|
// updated twice per iter to break the dep chain. Tail bytes go through a
|
||||||
|
// scalar ADDS/ADC chain; the byte-swapped initial is merged in at fold.
|
||||||
|
TEXT ·checksumNEON(SB), NOSPLIT, $0-34
|
||||||
|
MOVD buf_base+0(FP), R0
|
||||||
|
MOVD buf_len+8(FP), R1
|
||||||
|
MOVHU initial+24(FP), R2
|
||||||
|
|
||||||
|
// Pre-byteswap initial into the LE-summing space so it merges directly
|
||||||
|
// with the rest of the accumulator.
|
||||||
|
REV16W R2, R2
|
||||||
|
|
||||||
|
MOVD ZR, R3 // scalar accumulator
|
||||||
|
|
||||||
|
CMP $32, R1
|
||||||
|
BLT scalar_tail
|
||||||
|
|
||||||
|
VEOR V8.B16, V8.B16, V8.B16
|
||||||
|
VEOR V9.B16, V9.B16, V9.B16
|
||||||
|
VEOR V10.B16, V10.B16, V10.B16
|
||||||
|
VEOR V11.B16, V11.B16, V11.B16
|
||||||
|
|
||||||
|
CMP $64, R1
|
||||||
|
BLT loop16_init
|
||||||
|
|
||||||
|
loop64:
|
||||||
|
VLD1.P 64(R0), [V0.B16, V1.B16, V2.B16, V3.B16]
|
||||||
|
VUADDW V0.S2, V8.D2, V8.D2
|
||||||
|
VUADDW2 V0.S4, V9.D2, V9.D2
|
||||||
|
VUADDW V1.S2, V10.D2, V10.D2
|
||||||
|
VUADDW2 V1.S4, V11.D2, V11.D2
|
||||||
|
VUADDW V2.S2, V8.D2, V8.D2
|
||||||
|
VUADDW2 V2.S4, V9.D2, V9.D2
|
||||||
|
VUADDW V3.S2, V10.D2, V10.D2
|
||||||
|
VUADDW2 V3.S4, V11.D2, V11.D2
|
||||||
|
SUB $64, R1, R1
|
||||||
|
CMP $64, R1
|
||||||
|
BGE loop64
|
||||||
|
|
||||||
|
loop16_init:
|
||||||
|
CMP $16, R1
|
||||||
|
BLT reduce_vec
|
||||||
|
|
||||||
|
loop16:
|
||||||
|
VLD1.P 16(R0), [V0.B16]
|
||||||
|
VUADDW V0.S2, V8.D2, V8.D2
|
||||||
|
VUADDW2 V0.S4, V9.D2, V9.D2
|
||||||
|
SUB $16, R1, R1
|
||||||
|
CMP $16, R1
|
||||||
|
BGE loop16
|
||||||
|
|
||||||
|
reduce_vec:
|
||||||
|
// Combine the four accumulators into V8.
|
||||||
|
VADD V9.D2, V8.D2, V8.D2
|
||||||
|
VADD V11.D2, V10.D2, V10.D2
|
||||||
|
VADD V10.D2, V8.D2, V8.D2
|
||||||
|
|
||||||
|
// Horizontal-add the two lanes of V8.D2 into a single uint64.
|
||||||
|
VADDP V8.D2, V8.D2, V8.D2
|
||||||
|
VMOV V8.D[0], R8
|
||||||
|
|
||||||
|
ADDS R8, R3, R3
|
||||||
|
ADC ZR, R3, R3
|
||||||
|
|
||||||
|
scalar_tail:
|
||||||
|
CMP $8, R1
|
||||||
|
BLT tail4
|
||||||
|
|
||||||
|
loop8:
|
||||||
|
MOVD.P 8(R0), R8
|
||||||
|
ADDS R8, R3, R3
|
||||||
|
ADC ZR, R3, R3
|
||||||
|
SUB $8, R1, R1
|
||||||
|
CMP $8, R1
|
||||||
|
BGE loop8
|
||||||
|
|
||||||
|
tail4:
|
||||||
|
CMP $4, R1
|
||||||
|
BLT tail2
|
||||||
|
MOVWU.P 4(R0), R8
|
||||||
|
ADDS R8, R3, R3
|
||||||
|
ADC ZR, R3, R3
|
||||||
|
SUB $4, R1, R1
|
||||||
|
|
||||||
|
tail2:
|
||||||
|
CMP $2, R1
|
||||||
|
BLT tail1
|
||||||
|
MOVHU.P 2(R0), R8
|
||||||
|
ADDS R8, R3, R3
|
||||||
|
ADC ZR, R3, R3
|
||||||
|
SUB $2, R1, R1
|
||||||
|
|
||||||
|
tail1:
|
||||||
|
CBZ R1, fold
|
||||||
|
MOVBU (R0), R8
|
||||||
|
ADDS R8, R3, R3
|
||||||
|
ADC ZR, R3, R3
|
||||||
|
|
||||||
|
fold:
|
||||||
|
// Merge the byte-swapped initial into our LE-form accumulator.
|
||||||
|
ADDS R2, R3, R3
|
||||||
|
ADC ZR, R3, R3
|
||||||
|
|
||||||
|
// 64 → 33 bits.
|
||||||
|
LSR $32, R3, R8
|
||||||
|
AND $0xffffffff, R3, R3
|
||||||
|
ADD R8, R3, R3
|
||||||
|
|
||||||
|
// 33 → 32 (truncate after adding bit 32 back).
|
||||||
|
LSR $32, R3, R8
|
||||||
|
ADD R8, R3, R3
|
||||||
|
AND $0xffffffff, R3, R3
|
||||||
|
|
||||||
|
// 32 → 17.
|
||||||
|
LSR $16, R3, R8
|
||||||
|
AND $0xffff, R3, R3
|
||||||
|
ADD R8, R3, R3
|
||||||
|
|
||||||
|
// 17 → 16 (truncation absorbs bit 16 below).
|
||||||
|
LSR $16, R3, R8
|
||||||
|
ADD R8, R3, R3
|
||||||
|
|
||||||
|
// R3 low 16 bits hold the 16-bit sum in machine (LE) byte order; flip
|
||||||
|
// to big-endian to match the gvisor API contract. REV16W swaps bytes
|
||||||
|
// within each 16-bit halfword of the low 32 bits, so it acts as a
|
||||||
|
// 16-bit byte-swap on the live low 16.
|
||||||
|
REV16W R3, R3
|
||||||
|
AND $0xffff, R3, R3
|
||||||
|
|
||||||
|
MOVH R3, ret+32(FP)
|
||||||
|
RET
|
||||||
10
overlay/checksum/checksum_fallback.go
Normal file
10
overlay/checksum/checksum_fallback.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !amd64 && !arm64
|
||||||
|
|
||||||
|
package checksum
|
||||||
|
|
||||||
|
import gvisorchecksum "gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||||
|
|
||||||
|
// Checksum delegates to gvisor on architectures without a hand-written body.
|
||||||
|
func Checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
return gvisorchecksum.Checksum(buf, initial)
|
||||||
|
}
|
||||||
190
overlay/checksum/checksum_test.go
Normal file
190
overlay/checksum/checksum_test.go
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
package checksum
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
gvisorchecksum "gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestChecksumMatchesGvisor walks lengths from 0 to 4096, with several initial
|
||||||
|
// seeds and a handful of starting alignments, asserting that our local
|
||||||
|
// Checksum matches gvisor's reference bit-for-bit.
|
||||||
|
func TestChecksumMatchesGvisor(t *testing.T) {
|
||||||
|
rng := rand.New(rand.NewPCG(1, 2))
|
||||||
|
const padFront = 16
|
||||||
|
|
||||||
|
// Random pool large enough for the longest case + alignment slop.
|
||||||
|
pool := make([]byte, 4096+padFront)
|
||||||
|
for i := range pool {
|
||||||
|
pool[i] = byte(rng.Uint32())
|
||||||
|
}
|
||||||
|
|
||||||
|
seeds := []uint16{0, 0x0001, 0xabcd, 0xffff, 0x1234, 0xfedc}
|
||||||
|
offsets := []int{0, 1, 2, 3, 4, 5, 7, 8, 15, 16}
|
||||||
|
|
||||||
|
for length := 0; length <= 4096; length++ {
|
||||||
|
for _, seed := range seeds {
|
||||||
|
for _, off := range offsets {
|
||||||
|
if off+length > len(pool) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buf := pool[off : off+length]
|
||||||
|
want := gvisorchecksum.Checksum(buf, seed)
|
||||||
|
got := Checksum(buf, seed)
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("len=%d off=%d seed=%#x: got %#04x want %#04x",
|
||||||
|
length, off, seed, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChecksumPatternedBuffers exercises specific byte patterns that have
|
||||||
|
// historically tripped up checksum implementations: all-zero, all-0xff,
|
||||||
|
// alternating, and ascending sequences.
|
||||||
|
func TestChecksumPatternedBuffers(t *testing.T) {
|
||||||
|
for length := 0; length <= 256; length++ {
|
||||||
|
patterns := map[string][]byte{
|
||||||
|
"zeros": make([]byte, length),
|
||||||
|
"ones": bytes(length, 0xff),
|
||||||
|
"alternating": pattern(length, []byte{0xa5, 0x5a}),
|
||||||
|
"ascending": ascending(length),
|
||||||
|
}
|
||||||
|
for name, buf := range patterns {
|
||||||
|
for _, seed := range []uint16{0, 0xffff, 0x8000} {
|
||||||
|
want := gvisorchecksum.Checksum(buf, seed)
|
||||||
|
got := Checksum(buf, seed)
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("%s len=%d seed=%#x: got %#04x want %#04x",
|
||||||
|
name, length, seed, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytes(n int, v byte) []byte {
|
||||||
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = v
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func pattern(n int, p []byte) []byte {
|
||||||
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = p[i%len(p)]
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func ascending(n int) []byte {
|
||||||
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = byte(i)
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChecksumTailPaths targets every combination of (SIMD body iterations,
|
||||||
|
// trailing tail bytes) the asm handlers walk through. The tail handlers
|
||||||
|
// peel off 8 → 4 → 2 → 1 byte chunks in turn; this test exercises each by
|
||||||
|
// constructing lengths of the form 64*k + tail for tail ∈ [0, 63] and a
|
||||||
|
// representative spread of k values, including k=0 (no main loop, all tail)
|
||||||
|
// and k=1 (one main loop iter, then tail). It's explicit coverage for
|
||||||
|
// payload sizes that are odd, not divisible by 4, by 8, or by 32.
|
||||||
|
func TestChecksumTailPaths(t *testing.T) {
|
||||||
|
rng := rand.New(rand.NewPCG(42, 17))
|
||||||
|
const padFront = 16
|
||||||
|
const maxK = 8
|
||||||
|
|
||||||
|
pool := make([]byte, 64*maxK+padFront+64)
|
||||||
|
for i := range pool {
|
||||||
|
pool[i] = byte(rng.Uint32())
|
||||||
|
}
|
||||||
|
|
||||||
|
seeds := []uint16{0, 0xffff, 0xabcd}
|
||||||
|
offsets := []int{0, 1, 3, 7, 15} // mix of aligned and odd starts
|
||||||
|
|
||||||
|
for k := 0; k <= maxK; k++ {
|
||||||
|
for tail := 0; tail < 64; tail++ {
|
||||||
|
length := 64*k + tail
|
||||||
|
for _, seed := range seeds {
|
||||||
|
for _, off := range offsets {
|
||||||
|
if off+length > len(pool) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buf := pool[off : off+length]
|
||||||
|
want := gvisorchecksum.Checksum(buf, seed)
|
||||||
|
got := Checksum(buf, seed)
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("k=%d tail=%d (len=%d) off=%d seed=%#x: got %#04x want %#04x",
|
||||||
|
k, tail, length, off, seed, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkChecksumTailSizes covers payload sizes that aren't clean multiples
|
||||||
|
// of the SIMD body's 32-byte (amd64) or 16-byte (arm64) chunks, so the tail
|
||||||
|
// handler is meaningfully on the hot path. Sizes are picked to either exercise
|
||||||
|
// every tail branch (tiny lengths) or sit slightly off realistic packet
|
||||||
|
// boundaries (e.g. 1499 = MTU − 1).
|
||||||
|
func BenchmarkChecksumTailSizes(b *testing.B) {
|
||||||
|
sizes := []int{
|
||||||
|
1, 3, 7, 15, 31, // sub-SIMD; entire work is scalar tail
|
||||||
|
33, 35, 47, 63, // one loop32 + assorted tails
|
||||||
|
65, 95, 127, // one loop64 + assorted tails
|
||||||
|
1447, 1471, 1499, 1501, // around MTU
|
||||||
|
8191, 8193, // around USO
|
||||||
|
65531, 65533, // near the kernel max
|
||||||
|
}
|
||||||
|
for _, size := range sizes {
|
||||||
|
buf := make([]byte, size)
|
||||||
|
for i := range buf {
|
||||||
|
buf[i] = byte(i)
|
||||||
|
}
|
||||||
|
b.Run(fmt.Sprintf("size=%d/local", size), func(b *testing.B) {
|
||||||
|
b.SetBytes(int64(size))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = Checksum(buf, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
b.Run(fmt.Sprintf("size=%d/gvisor", size), func(b *testing.B) {
|
||||||
|
b.SetBytes(int64(size))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = gvisorchecksum.Checksum(buf, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkChecksum compares the local Checksum to gvisor's at sizes that
|
||||||
|
// match real traffic: a TCP/IP header (60), a typical MSS (1448), a typical
|
||||||
|
// USO size (8192), and the kernel's max GSO superpacket (65535).
|
||||||
|
func BenchmarkChecksum(b *testing.B) {
|
||||||
|
for _, size := range []int{60, 1448, 8192, 65535} {
|
||||||
|
buf := make([]byte, size)
|
||||||
|
for i := range buf {
|
||||||
|
buf[i] = byte(i)
|
||||||
|
}
|
||||||
|
b.Run(fmt.Sprintf("size=%d/local", size), func(b *testing.B) {
|
||||||
|
b.SetBytes(int64(size))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = Checksum(buf, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
b.Run(fmt.Sprintf("size=%d/gvisor", size), func(b *testing.B) {
|
||||||
|
b.SetBytes(int64(size))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = gvisorchecksum.Checksum(buf, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,7 +14,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
|
||||||
|
"github.com/slackhq/nebula/overlay/checksum"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
|
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
|
||||||
|
|||||||
Reference in New Issue
Block a user