mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
robot say this faster
This commit is contained in:
@@ -123,6 +123,12 @@ func finishChecksum(seg []byte, hdr virtioNetHdr) error {
|
|||||||
|
|
||||||
// segmentTCP software-segments a TSO superpacket into one IP packet per MSS
|
// segmentTCP software-segments a TSO superpacket into one IP packet per MSS
|
||||||
// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6.
|
// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6.
|
||||||
|
//
|
||||||
|
// Hot-path shape: the per-segment loop only sums the payload chunk. The TCP
|
||||||
|
// header, the IPv4 header, and the pseudo-header src/dst/proto contributions
|
||||||
|
// are each summed once up front — every segment reuses those three pre-folded
|
||||||
|
// uint32 values and combines them with small per-segment deltas (seq, flags,
|
||||||
|
// tcpLen, ip_id, total_len) that are cheap to fold in.
|
||||||
func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
|
func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
|
||||||
if hdr.GSOSize == 0 {
|
if hdr.GSOSize == 0 {
|
||||||
return fmt.Errorf("gso_size is zero")
|
return fmt.Errorf("gso_size is zero")
|
||||||
@@ -144,8 +150,9 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
if !isV4 && csumStart < 40 {
|
if !isV4 && csumStart < 40 {
|
||||||
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
|
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
|
||||||
}
|
}
|
||||||
if headerLen-csumStart < 20 {
|
tcpHdrLen := headerLen - csumStart
|
||||||
return fmt.Errorf("tcp header region too small: %d", headerLen-csumStart)
|
if tcpHdrLen < 20 {
|
||||||
|
return fmt.Errorf("tcp header region too small: %d", tcpHdrLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
payload := pkt[headerLen:]
|
payload := pkt[headerLen:]
|
||||||
@@ -165,9 +172,43 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
origFlags := pkt[csumStart+13]
|
origFlags := pkt[csumStart+13]
|
||||||
const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08)
|
const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08)
|
||||||
|
|
||||||
|
// Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP
|
||||||
|
// header is 60 bytes; copy onto the stack, zero the per-segment-varying
|
||||||
|
// fields, sum once.
|
||||||
|
var tmp [60]byte
|
||||||
|
copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
|
||||||
|
tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq
|
||||||
|
tmp[13] = 0 // flags
|
||||||
|
tmp[16], tmp[17] = 0, 0 // csum
|
||||||
|
baseTcpHdrSum := checksumBytes(tmp[:tcpHdrLen], 0)
|
||||||
|
|
||||||
|
// Pseudo-header src+dst+proto contribution (tcpLen varies per segment).
|
||||||
|
var baseProtoSum uint32
|
||||||
|
if isV4 {
|
||||||
|
baseProtoSum = checksumBytes(pkt[12:16], 0)
|
||||||
|
baseProtoSum = checksumBytes(pkt[16:20], baseProtoSum)
|
||||||
|
} else {
|
||||||
|
baseProtoSum = checksumBytes(pkt[8:24], 0)
|
||||||
|
baseProtoSum = checksumBytes(pkt[24:40], baseProtoSum)
|
||||||
|
}
|
||||||
|
baseProtoSum += uint32(unix.IPPROTO_TCP)
|
||||||
|
|
||||||
|
// Precompute IPv4 header sum with total_len/id/csum zeroed.
|
||||||
var origIPID uint16
|
var origIPID uint16
|
||||||
|
var ihl int
|
||||||
|
var baseIPHdrSum uint32
|
||||||
if isV4 {
|
if isV4 {
|
||||||
origIPID = binary.BigEndian.Uint16(pkt[4:6])
|
origIPID = binary.BigEndian.Uint16(pkt[4:6])
|
||||||
|
ihl = int(pkt[0]&0x0f) * 4
|
||||||
|
if ihl < 20 || ihl > csumStart {
|
||||||
|
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
|
||||||
|
}
|
||||||
|
var ipTmp [60]byte
|
||||||
|
copy(ipTmp[:ihl], pkt[:ihl])
|
||||||
|
ipTmp[2], ipTmp[3] = 0, 0 // total_len
|
||||||
|
ipTmp[4], ipTmp[5] = 0, 0 // id
|
||||||
|
ipTmp[10], ipTmp[11] = 0, 0 // checksum
|
||||||
|
baseIPHdrSum = checksumBytes(ipTmp[:ihl], 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
off := 0
|
off := 0
|
||||||
@@ -179,48 +220,47 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
}
|
}
|
||||||
segPayLen := segEnd - segStart
|
segPayLen := segEnd - segStart
|
||||||
|
|
||||||
// Materialise IP+TCP header and this segment's payload chunk.
|
|
||||||
copy(scratch[off:], pkt[:headerLen])
|
copy(scratch[off:], pkt[:headerLen])
|
||||||
copy(scratch[off+headerLen:], payload[segStart:segEnd])
|
copy(scratch[off+headerLen:], payload[segStart:segEnd])
|
||||||
seg := scratch[off : off+headerLen+segPayLen]
|
seg := scratch[off : off+headerLen+segPayLen]
|
||||||
off += headerLen + segPayLen
|
off += headerLen + segPayLen
|
||||||
|
|
||||||
// Fix IP header: total/payload length, v4 ID, v4 header csum.
|
segSeq := origSeq + uint32(segStart)
|
||||||
if isV4 {
|
segFlags := origFlags
|
||||||
ihl := int(seg[0]&0x0f) * 4
|
if i != numSeg-1 {
|
||||||
if ihl < 20 || ihl > csumStart {
|
segFlags = origFlags &^ tcpFinPsh
|
||||||
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
|
|
||||||
}
|
}
|
||||||
binary.BigEndian.PutUint16(seg[2:4], uint16(headerLen+segPayLen))
|
totalLen := headerLen + segPayLen
|
||||||
binary.BigEndian.PutUint16(seg[4:6], origIPID+uint16(i))
|
|
||||||
seg[10] = 0
|
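// Patch IP header and write the v4 header checksum from the precomputed base.
// NOTE(review): baseIPHdrSum is a full 32-bit folded sum, so the uint32 add below can wrap and drop a carry when it is close to 2^32; consider accumulating in uint64 as the TCP checksum does — confirm.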
|
||||||
seg[11] = 0
|
if isV4 {
|
||||||
binary.BigEndian.PutUint16(seg[10:12], checksumFold(checksumBytes(seg[:ihl], 0)))
|
segID := origIPID + uint16(i)
|
||||||
|
binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen))
|
||||||
|
binary.BigEndian.PutUint16(seg[4:6], segID)
|
||||||
|
ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
|
||||||
|
binary.BigEndian.PutUint16(seg[10:12], checksumFold(ipSum))
|
||||||
} else {
|
} else {
|
||||||
// IPv6 payload length excludes the 40-byte fixed header but
|
// IPv6 payload length excludes the 40-byte fixed header but
|
||||||
// includes any extension headers that sit between [40:csumStart].
|
// includes any extension headers between [40:csumStart].
|
||||||
binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen))
|
binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fix TCP header: seq, flags, checksum.
|
// Patch TCP header.
|
||||||
segSeq := origSeq + uint32(segStart)
|
|
||||||
binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq)
|
binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq)
|
||||||
if i != numSeg-1 {
|
seg[csumStart+13] = segFlags
|
||||||
seg[csumStart+13] = origFlags &^ tcpFinPsh
|
// (csum is written below; its prior contents in `seg` don't affect the
|
||||||
} else {
|
// computation since we never sum over the segment's own header.)
|
||||||
seg[csumStart+13] = origFlags
|
|
||||||
}
|
|
||||||
seg[csumStart+16] = 0
|
|
||||||
seg[csumStart+17] = 0
|
|
||||||
|
|
||||||
tcpLen := headerLen - csumStart + segPayLen
|
tcpLen := tcpHdrLen + segPayLen
|
||||||
var psum uint32
|
paySum := checksumBytes(payload[segStart:segEnd], 0)
|
||||||
if isV4 {
|
|
||||||
psum = pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, tcpLen)
|
// Combine pre-folded uint32s into a wider accumulator, then fold. Using
|
||||||
} else {
|
// uint64 guards against overflow when segSeq's high bits are set.
|
||||||
psum = pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen)
|
wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
|
||||||
}
|
wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
|
||||||
binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(checksumBytes(seg[csumStart:csumStart+tcpLen], psum)))
|
wide = (wide & 0xffffffff) + (wide >> 32)
|
||||||
|
wide = (wide & 0xffffffff) + (wide >> 32)
|
||||||
|
binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(uint32(wide)))
|
||||||
|
|
||||||
*out = append(*out, seg)
|
*out = append(*out, seg)
|
||||||
}
|
}
|
||||||
@@ -231,28 +271,27 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
// checksumBytes returns the Internet-checksum partial sum of b, seeded with
|
// checksumBytes returns the Internet-checksum partial sum of b, seeded with
|
||||||
// initial. Result is a 32-bit accumulator; the caller folds to 16.
|
// initial. Result is a 32-bit accumulator; the caller folds to 16.
|
||||||
//
|
//
|
||||||
// Wide-word variant: each 8-byte load contributes four 16-bit lanes to a
|
// Each 4-byte load is added directly into a 64-bit accumulator. Two parallel
|
||||||
// 64-bit accumulator, cutting the number of loads, shifts, and slice reslices
|
// accumulators break the serial dependency through `sum` and let the CPU
|
||||||
// ~4x versus the naive Uint16 loop. The 64-bit accumulator has ample headroom
|
// overlap independent adds. The 64 → 32 fold below (and the caller's 32 → 16 fold) handles the
|
||||||
// — worst case is (initial=2^32) + (64KiB / 2) * 0xffff ≈ 2.5 * 10^9, far
|
// carries that accumulated across the 32-bit lane boundary.
|
||||||
// below 2^64 — so no mid-loop fold is needed.
|
|
||||||
func checksumBytes(b []byte, initial uint32) uint32 {
|
func checksumBytes(b []byte, initial uint32) uint32 {
|
||||||
sum := uint64(initial)
|
s0 := uint64(initial)
|
||||||
for len(b) >= 16 {
|
var s1 uint64
|
||||||
w1 := binary.BigEndian.Uint64(b[:8])
|
for len(b) >= 32 {
|
||||||
w2 := binary.BigEndian.Uint64(b[8:16])
|
s0 += uint64(binary.BigEndian.Uint32(b[0:4]))
|
||||||
sum += (w1 >> 48) + ((w1 >> 32) & 0xffff) + ((w1 >> 16) & 0xffff) + (w1 & 0xffff)
|
s1 += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
sum += (w2 >> 48) + ((w2 >> 32) & 0xffff) + ((w2 >> 16) & 0xffff) + (w2 & 0xffff)
|
s0 += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
b = b[16:]
|
s1 += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
s0 += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
s1 += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
s0 += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
s1 += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
b = b[32:]
|
||||||
}
|
}
|
||||||
if len(b) >= 8 {
|
sum := s0 + s1
|
||||||
w := binary.BigEndian.Uint64(b[:8])
|
for len(b) >= 4 {
|
||||||
sum += (w >> 48) + ((w >> 32) & 0xffff) + ((w >> 16) & 0xffff) + (w & 0xffff)
|
sum += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
b = b[8:]
|
|
||||||
}
|
|
||||||
if len(b) >= 4 {
|
|
||||||
w := binary.BigEndian.Uint32(b[:4])
|
|
||||||
sum += uint64(w>>16) + uint64(w&0xffff)
|
|
||||||
b = b[4:]
|
b = b[4:]
|
||||||
}
|
}
|
||||||
if len(b) >= 2 {
|
if len(b) >= 2 {
|
||||||
@@ -262,8 +301,6 @@ func checksumBytes(b []byte, initial uint32) uint32 {
|
|||||||
if len(b) == 1 {
|
if len(b) == 1 {
|
||||||
sum += uint64(b[0]) << 8
|
sum += uint64(b[0]) << 8
|
||||||
}
|
}
|
||||||
// Fold 64 → 32. The checksum is one's complement, so carries are
|
|
||||||
// end-around-added; once the high 32 bits are zero we're done.
|
|
||||||
sum = (sum & 0xffffffff) + (sum >> 32)
|
sum = (sum & 0xffffffff) + (sum >> 32)
|
||||||
sum = (sum & 0xffffffff) + (sum >> 32)
|
sum = (sum & 0xffffffff) + (sum >> 32)
|
||||||
return uint32(sum)
|
return uint32(sum)
|
||||||
|
|||||||
@@ -247,6 +247,62 @@ func TestSegmentRejectsUDP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkSegmentTCPv4(b *testing.B) {
|
||||||
|
sizes := []struct {
|
||||||
|
name string
|
||||||
|
payLen int
|
||||||
|
mss int
|
||||||
|
}{
|
||||||
|
{"64KiB_MSS1460", 65000, 1460},
|
||||||
|
{"16KiB_MSS1460", 16384, 1460},
|
||||||
|
{"4KiB_MSS1460", 4096, 1460},
|
||||||
|
}
|
||||||
|
for _, sz := range sizes {
|
||||||
|
b.Run(sz.name, func(b *testing.B) {
|
||||||
|
const ipLen = 20
|
||||||
|
const tcpLen = 20
|
||||||
|
pkt := make([]byte, ipLen+tcpLen+sz.payLen)
|
||||||
|
pkt[0] = 0x45
|
||||||
|
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
|
||||||
|
binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
|
||||||
|
pkt[8] = 64
|
||||||
|
pkt[9] = unix.IPPROTO_TCP
|
||||||
|
copy(pkt[12:16], []byte{10, 0, 0, 1})
|
||||||
|
copy(pkt[16:20], []byte{10, 0, 0, 2})
|
||||||
|
binary.BigEndian.PutUint16(pkt[20:22], 12345)
|
||||||
|
binary.BigEndian.PutUint16(pkt[22:24], 80)
|
||||||
|
binary.BigEndian.PutUint32(pkt[24:28], 10000)
|
||||||
|
binary.BigEndian.PutUint32(pkt[28:32], 20000)
|
||||||
|
pkt[32] = 0x50
|
||||||
|
pkt[33] = 0x18
|
||||||
|
binary.BigEndian.PutUint16(pkt[34:36], 65535)
|
||||||
|
for i := 0; i < sz.payLen; i++ {
|
||||||
|
pkt[ipLen+tcpLen+i] = byte(i)
|
||||||
|
}
|
||||||
|
hdr := virtioNetHdr{
|
||||||
|
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||||
|
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
||||||
|
HdrLen: uint16(ipLen + tcpLen),
|
||||||
|
GSOSize: uint16(sz.mss),
|
||||||
|
CsumStart: uint16(ipLen),
|
||||||
|
CsumOffset: 16,
|
||||||
|
}
|
||||||
|
|
||||||
|
scratch := make([]byte, tunSegBufSize)
|
||||||
|
out := make([][]byte, 0, 64)
|
||||||
|
|
||||||
|
b.SetBytes(int64(len(pkt)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
out = out[:0]
|
||||||
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
|
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
|
||||||
// allocation-free. We write to /dev/null so every call succeeds synchronously.
|
// allocation-free. We write to /dev/null so every call succeeds synchronously.
|
||||||
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
|
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user