diff --git a/overlay/tun_linux_offload.go b/overlay/tun_linux_offload.go index d8b0eeac..660c80ad 100644 --- a/overlay/tun_linux_offload.go +++ b/overlay/tun_linux_offload.go @@ -123,6 +123,12 @@ func finishChecksum(seg []byte, hdr virtioNetHdr) error { // segmentTCP software-segments a TSO superpacket into one IP packet per MSS // chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6. +// +// Hot-path shape: the per-segment loop only sums the payload chunk. The TCP +// header, the IPv4 header, and the pseudo-header src/dst/proto contributions +// are each summed once up front — every segment reuses those three pre-folded +// uint32 values and combines them with small per-segment deltas (seq, flags, +// tcpLen, ip_id, total_len) that are cheap to fold in. func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error { if hdr.GSOSize == 0 { return fmt.Errorf("gso_size is zero") @@ -144,8 +150,9 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err if !isV4 && csumStart < 40 { return fmt.Errorf("csum_start %d too small for IPv6", csumStart) } - if headerLen-csumStart < 20 { - return fmt.Errorf("tcp header region too small: %d", headerLen-csumStart) + tcpHdrLen := headerLen - csumStart + if tcpHdrLen < 20 { + return fmt.Errorf("tcp header region too small: %d", tcpHdrLen) } payload := pkt[headerLen:] @@ -165,9 +172,43 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err origFlags := pkt[csumStart+13] const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08) + // Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP + // header is 60 bytes; copy onto the stack, zero the per-segment-varying + // fields, sum once. 
+ var tmp [60]byte + copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen]) + tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq + tmp[13] = 0 // flags + tmp[16], tmp[17] = 0, 0 // csum + baseTcpHdrSum := checksumBytes(tmp[:tcpHdrLen], 0) + + // Pseudo-header src+dst+proto contribution (tcpLen varies per segment). + var baseProtoSum uint32 + if isV4 { + baseProtoSum = checksumBytes(pkt[12:16], 0) + baseProtoSum = checksumBytes(pkt[16:20], baseProtoSum) + } else { + baseProtoSum = checksumBytes(pkt[8:24], 0) + baseProtoSum = checksumBytes(pkt[24:40], baseProtoSum) + } + baseProtoSum += uint32(unix.IPPROTO_TCP) + + // Precompute IPv4 header sum with total_len/id/csum zeroed. var origIPID uint16 + var ihl int + var baseIPHdrSum uint32 if isV4 { origIPID = binary.BigEndian.Uint16(pkt[4:6]) + ihl = int(pkt[0]&0x0f) * 4 + if ihl < 20 || ihl > csumStart { + return fmt.Errorf("bad IPv4 IHL: %d", ihl) + } + var ipTmp [60]byte + copy(ipTmp[:ihl], pkt[:ihl]) + ipTmp[2], ipTmp[3] = 0, 0 // total_len + ipTmp[4], ipTmp[5] = 0, 0 // id + ipTmp[10], ipTmp[11] = 0, 0 // checksum + baseIPHdrSum = checksumBytes(ipTmp[:ihl], 0) } off := 0 @@ -179,48 +220,47 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err } segPayLen := segEnd - segStart - // Materialise IP+TCP header and this segment's payload chunk. copy(scratch[off:], pkt[:headerLen]) copy(scratch[off+headerLen:], payload[segStart:segEnd]) seg := scratch[off : off+headerLen+segPayLen] off += headerLen + segPayLen - // Fix IP header: total/payload length, v4 ID, v4 header csum. + segSeq := origSeq + uint32(segStart) + segFlags := origFlags + if i != numSeg-1 { + segFlags = origFlags &^ tcpFinPsh + } + totalLen := headerLen + segPayLen + + // Patch IP header and write the v4 header checksum from the precomputed base. 
if isV4 { - ihl := int(seg[0]&0x0f) * 4 - if ihl < 20 || ihl > csumStart { - return fmt.Errorf("bad IPv4 IHL: %d", ihl) - } - binary.BigEndian.PutUint16(seg[2:4], uint16(headerLen+segPayLen)) - binary.BigEndian.PutUint16(seg[4:6], origIPID+uint16(i)) - seg[10] = 0 - seg[11] = 0 - binary.BigEndian.PutUint16(seg[10:12], checksumFold(checksumBytes(seg[:ihl], 0))) + segID := origIPID + uint16(i) + binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen)) + binary.BigEndian.PutUint16(seg[4:6], segID) + ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID) + binary.BigEndian.PutUint16(seg[10:12], checksumFold(ipSum)) } else { // IPv6 payload length excludes the 40-byte fixed header but - // includes any extension headers that sit between [40:csumStart]. + // includes any extension headers between [40:csumStart]. binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen)) } - // Fix TCP header: seq, flags, checksum. - segSeq := origSeq + uint32(segStart) + // Patch TCP header. binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq) - if i != numSeg-1 { - seg[csumStart+13] = origFlags &^ tcpFinPsh - } else { - seg[csumStart+13] = origFlags - } - seg[csumStart+16] = 0 - seg[csumStart+17] = 0 + seg[csumStart+13] = segFlags + // (csum is written below; its prior contents in `seg` don't affect the + // computation since we never sum over the segment's own header.) - tcpLen := headerLen - csumStart + segPayLen - var psum uint32 - if isV4 { - psum = pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, tcpLen) - } else { - psum = pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen) - } - binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(checksumBytes(seg[csumStart:csumStart+tcpLen], psum))) + tcpLen := tcpHdrLen + segPayLen + paySum := checksumBytes(payload[segStart:segEnd], 0) + + // Combine pre-folded uint32s into a wider accumulator, then fold. Using + // uint64 guards against overflow when segSeq's high bits are set. 
+ wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum) + wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen) + wide = (wide & 0xffffffff) + (wide >> 32) + wide = (wide & 0xffffffff) + (wide >> 32) + binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(uint32(wide))) *out = append(*out, seg) } @@ -231,28 +271,27 @@ func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) err // checksumBytes returns the Internet-checksum partial sum of b, seeded with // initial. Result is a 32-bit accumulator; the caller folds to 16. // -// Wide-word variant: each 8-byte load contributes four 16-bit lanes to a -// 64-bit accumulator, cutting the number of loads, shifts, and slice reslices -// ~4x versus the naive Uint16 loop. The 64-bit accumulator has ample headroom -// — worst case is (initial=2^32) + (64KiB / 2) * 0xffff ≈ 2.5 * 10^9, far -// below 2^64 — so no mid-loop fold is needed. +// Each 4-byte load is added directly into a 64-bit accumulator. Two parallel +// accumulators break the serial dependency through `sum` and let the CPU +// overlap independent adds. The final fold from 64 → 32 → 16 handles the +// carries that accumulated across the 32-bit lane boundary. 
func checksumBytes(b []byte, initial uint32) uint32 { - sum := uint64(initial) - for len(b) >= 16 { - w1 := binary.BigEndian.Uint64(b[:8]) - w2 := binary.BigEndian.Uint64(b[8:16]) - sum += (w1 >> 48) + ((w1 >> 32) & 0xffff) + ((w1 >> 16) & 0xffff) + (w1 & 0xffff) - sum += (w2 >> 48) + ((w2 >> 32) & 0xffff) + ((w2 >> 16) & 0xffff) + (w2 & 0xffff) - b = b[16:] + s0 := uint64(initial) + var s1 uint64 + for len(b) >= 32 { + s0 += uint64(binary.BigEndian.Uint32(b[0:4])) + s1 += uint64(binary.BigEndian.Uint32(b[4:8])) + s0 += uint64(binary.BigEndian.Uint32(b[8:12])) + s1 += uint64(binary.BigEndian.Uint32(b[12:16])) + s0 += uint64(binary.BigEndian.Uint32(b[16:20])) + s1 += uint64(binary.BigEndian.Uint32(b[20:24])) + s0 += uint64(binary.BigEndian.Uint32(b[24:28])) + s1 += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] } - if len(b) >= 8 { - w := binary.BigEndian.Uint64(b[:8]) - sum += (w >> 48) + ((w >> 32) & 0xffff) + ((w >> 16) & 0xffff) + (w & 0xffff) - b = b[8:] - } - if len(b) >= 4 { - w := binary.BigEndian.Uint32(b[:4]) - sum += uint64(w>>16) + uint64(w&0xffff) + sum := s0 + s1 + for len(b) >= 4 { + sum += uint64(binary.BigEndian.Uint32(b[:4])) b = b[4:] } if len(b) >= 2 { @@ -262,8 +301,6 @@ func checksumBytes(b []byte, initial uint32) uint32 { if len(b) == 1 { sum += uint64(b[0]) << 8 } - // Fold 64 → 32. The checksum is one's complement, so carries are - // end-around-added; once the high 32 bits are zero we're done. 
sum = (sum & 0xffffffff) + (sum >> 32) sum = (sum & 0xffffffff) + (sum >> 32) return uint32(sum) diff --git a/overlay/tun_linux_offload_test.go b/overlay/tun_linux_offload_test.go index da596011..f3f9da40 100644 --- a/overlay/tun_linux_offload_test.go +++ b/overlay/tun_linux_offload_test.go @@ -247,6 +247,62 @@ func TestSegmentRejectsUDP(t *testing.T) { } } +func BenchmarkSegmentTCPv4(b *testing.B) { + sizes := []struct { + name string + payLen int + mss int + }{ + {"64KiB_MSS1460", 65000, 1460}, + {"16KiB_MSS1460", 16384, 1460}, + {"4KiB_MSS1460", 4096, 1460}, + } + for _, sz := range sizes { + b.Run(sz.name, func(b *testing.B) { + const ipLen = 20 + const tcpLen = 20 + pkt := make([]byte, ipLen+tcpLen+sz.payLen) + pkt[0] = 0x45 + binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen)) + binary.BigEndian.PutUint16(pkt[4:6], 0x4242) + pkt[8] = 64 + pkt[9] = unix.IPPROTO_TCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + binary.BigEndian.PutUint16(pkt[20:22], 12345) + binary.BigEndian.PutUint16(pkt[22:24], 80) + binary.BigEndian.PutUint32(pkt[24:28], 10000) + binary.BigEndian.PutUint32(pkt[28:32], 20000) + pkt[32] = 0x50 + pkt[33] = 0x18 + binary.BigEndian.PutUint16(pkt[34:36], 65535) + for i := 0; i < sz.payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i) + } + hdr := virtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(sz.mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } + + scratch := make([]byte, tunSegBufSize) + out := make([][]byte, 0, 64) + + b.SetBytes(int64(len(pkt))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + out = out[:0] + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + b.Fatal(err) + } + } + }) + } +} + // TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is // allocation-free. We write to /dev/null so every call succeeds synchronously. 
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {