GSO again

This commit is contained in:
JackDoan
2026-04-17 10:25:05 -05:00
parent 6ee5e18d84
commit d0825514a0
31 changed files with 3278 additions and 164 deletions

View File

@@ -0,0 +1,331 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay
import (
"encoding/binary"
"fmt"
"golang.org/x/sys/unix"
)
// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on
// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set).
const virtioNetHdrLen = 10
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
// prefix plus the virtio header.
const tunReadBufSize = 65535
// Space for segmented output. Worst case is many small segments, each paying
// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling.
const tunSegBufSize = 131072
// tunSegBufCap is the total size we allocate for the per-reader segment
// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus
// the same again as drain headroom so a Read wake can accumulate
// additional packets after an initial big read without overflowing.
const tunSegBufCap = tunSegBufSize * 2
// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64
type virtioNetHdr struct {
Flags uint8
GSOType uint8
HdrLen uint16
GSOSize uint16
CsumStart uint16
CsumOffset uint16
}
// decode reads a virtio_net_hdr in host byte order (TUN default; we never
// call TUNSETVNETLE so the kernel matches our endianness).
func (h *virtioNetHdr) decode(b []byte) {
h.Flags = b[0]
h.GSOType = b[1]
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
}
// encode is the inverse of decode: writes the virtio_net_hdr fields into b
// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket
// on egress.
func (h *virtioNetHdr) encode(b []byte) {
b[0] = h.Flags
b[1] = h.GSOType
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
}
// segmentInto splits a TUN-side packet described by hdr into one or more
// IP packets, each appended to *out as a slice of scratch. scratch must be
// sized to hold every segment (including replicated headers).
func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
// carry coalescing info rather than checksum offsets. A TUN writing via
// IFF_VNET_HDR should never emit this, but if it did we would silently
// miscompute the segment checksums — refuse the packet instead.
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
}
switch hdr.GSOType {
case unix.VIRTIO_NET_HDR_GSO_NONE:
if len(pkt) > len(scratch) {
return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
}
copy(scratch, pkt)
seg := scratch[:len(pkt)]
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if err := finishChecksum(seg, hdr); err != nil {
return err
}
}
*out = append(*out, seg)
return nil
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
return segmentTCP(pkt, hdr, out, scratch)
default:
return fmt.Errorf("unsupported virtio gso type: %d", hdr.GSOType)
}
}
// finishChecksum computes the L4 checksum for a non-GSO packet that the kernel
// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit
// checksum field; we zero it, fold a full sum (the field was pre-loaded with
// the pseudo-header partial sum by the kernel), and store the result.
func finishChecksum(seg []byte, hdr virtioNetHdr) error {
cs := int(hdr.CsumStart)
co := int(hdr.CsumOffset)
if cs+co+2 > len(seg) {
return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg))
}
// The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the
// L4 region starting at cs, folding the prior partial in as the seed.
partial := uint32(binary.BigEndian.Uint16(seg[cs+co : cs+co+2]))
seg[cs+co] = 0
seg[cs+co+1] = 0
sum := checksumBytes(seg[cs:], partial)
binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], checksumFold(sum))
return nil
}
// segmentTCP software-segments a TSO superpacket into one IP packet per MSS
// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6.
//
// Hot-path shape: the per-segment loop only sums the payload chunk. The TCP
// header, the IPv4 header, and the pseudo-header src/dst/proto contributions
// are each summed once up front — every segment reuses those three pre-folded
// uint32 values and combines them with small per-segment deltas (seq, flags,
// tcpLen, ip_id, total_len) that are cheap to fold in.
func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
if hdr.GSOSize == 0 {
return fmt.Errorf("gso_size is zero")
}
if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 {
return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt))
}
if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen {
return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen)
}
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
headerLen := int(hdr.HdrLen)
csumStart := int(hdr.CsumStart)
if isV4 && csumStart < 20 {
return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
}
if !isV4 && csumStart < 40 {
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
}
tcpHdrLen := headerLen - csumStart
if tcpHdrLen < 20 {
return fmt.Errorf("tcp header region too small: %d", tcpHdrLen)
}
payload := pkt[headerLen:]
payLen := len(payload)
gso := int(hdr.GSOSize)
numSeg := (payLen + gso - 1) / gso
if numSeg == 0 {
numSeg = 1
}
need := numSeg*headerLen + payLen
if need > len(scratch) {
return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch))
}
origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8])
origFlags := pkt[csumStart+13]
const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08)
// Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP
// header is 60 bytes; copy onto the stack, zero the per-segment-varying
// fields, sum once.
var tmp [60]byte
copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq
tmp[13] = 0 // flags
tmp[16], tmp[17] = 0, 0 // csum
baseTcpHdrSum := checksumBytes(tmp[:tcpHdrLen], 0)
// Pseudo-header src+dst+proto contribution (tcpLen varies per segment).
var baseProtoSum uint32
if isV4 {
baseProtoSum = checksumBytes(pkt[12:16], 0)
baseProtoSum = checksumBytes(pkt[16:20], baseProtoSum)
} else {
baseProtoSum = checksumBytes(pkt[8:24], 0)
baseProtoSum = checksumBytes(pkt[24:40], baseProtoSum)
}
baseProtoSum += uint32(unix.IPPROTO_TCP)
// Precompute IPv4 header sum with total_len/id/csum zeroed.
var origIPID uint16
var ihl int
var baseIPHdrSum uint32
if isV4 {
origIPID = binary.BigEndian.Uint16(pkt[4:6])
ihl = int(pkt[0]&0x0f) * 4
if ihl < 20 || ihl > csumStart {
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
}
var ipTmp [60]byte
copy(ipTmp[:ihl], pkt[:ihl])
ipTmp[2], ipTmp[3] = 0, 0 // total_len
ipTmp[4], ipTmp[5] = 0, 0 // id
ipTmp[10], ipTmp[11] = 0, 0 // checksum
baseIPHdrSum = checksumBytes(ipTmp[:ihl], 0)
}
off := 0
for i := 0; i < numSeg; i++ {
segStart := i * gso
segEnd := segStart + gso
if segEnd > payLen {
segEnd = payLen
}
segPayLen := segEnd - segStart
copy(scratch[off:], pkt[:headerLen])
copy(scratch[off+headerLen:], payload[segStart:segEnd])
seg := scratch[off : off+headerLen+segPayLen]
off += headerLen + segPayLen
segSeq := origSeq + uint32(segStart)
segFlags := origFlags
if i != numSeg-1 {
segFlags = origFlags &^ tcpFinPsh
}
totalLen := headerLen + segPayLen
// Patch IP header and write the v4 header checksum from the precomputed base.
if isV4 {
segID := origIPID + uint16(i)
binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen))
binary.BigEndian.PutUint16(seg[4:6], segID)
ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
binary.BigEndian.PutUint16(seg[10:12], checksumFold(ipSum))
} else {
// IPv6 payload length excludes the 40-byte fixed header but
// includes any extension headers between [40:csumStart].
binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen))
}
// Patch TCP header.
binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq)
seg[csumStart+13] = segFlags
// (csum is written below; its prior contents in `seg` don't affect the
// computation since we never sum over the segment's own header.)
tcpLen := tcpHdrLen + segPayLen
paySum := checksumBytes(payload[segStart:segEnd], 0)
// Combine pre-folded uint32s into a wider accumulator, then fold. Using
// uint64 guards against overflow when segSeq's high bits set.
wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
wide = (wide & 0xffffffff) + (wide >> 32)
wide = (wide & 0xffffffff) + (wide >> 32)
binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(uint32(wide)))
*out = append(*out, seg)
}
return nil
}
// checksumBytes returns the Internet-checksum partial sum of b, seeded with
// initial. Result is a 32-bit accumulator; the caller folds to 16.
//
// Each 4-byte load is added directly into a 64-bit accumulator. Two parallel
// accumulators break the serial dependency through `sum` and let the CPU
// overlap independent adds. The final fold from 64 → 32 → 16 handles the
// carries that accumulated across the 32-bit lane boundary.
func checksumBytes(b []byte, initial uint32) uint32 {
s0 := uint64(initial)
var s1 uint64
for len(b) >= 32 {
s0 += uint64(binary.BigEndian.Uint32(b[0:4]))
s1 += uint64(binary.BigEndian.Uint32(b[4:8]))
s0 += uint64(binary.BigEndian.Uint32(b[8:12]))
s1 += uint64(binary.BigEndian.Uint32(b[12:16]))
s0 += uint64(binary.BigEndian.Uint32(b[16:20]))
s1 += uint64(binary.BigEndian.Uint32(b[20:24]))
s0 += uint64(binary.BigEndian.Uint32(b[24:28]))
s1 += uint64(binary.BigEndian.Uint32(b[28:32]))
b = b[32:]
}
sum := s0 + s1
for len(b) >= 4 {
sum += uint64(binary.BigEndian.Uint32(b[:4]))
b = b[4:]
}
if len(b) >= 2 {
sum += uint64(binary.BigEndian.Uint16(b[:2]))
b = b[2:]
}
if len(b) == 1 {
sum += uint64(b[0]) << 8
}
sum = (sum & 0xffffffff) + (sum >> 32)
sum = (sum & 0xffffffff) + (sum >> 32)
return uint32(sum)
}
func checksumFold(sum uint32) uint16 {
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return ^uint16(sum)
}
func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint32 {
sum := checksumBytes(src, 0)
sum = checksumBytes(dst, sum)
sum += uint32(proto)
sum += uint32(tcpLen)
return sum
}
func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint32 {
sum := checksumBytes(src, 0)
sum = checksumBytes(dst, sum)
sum += uint32(tcpLen >> 16)
sum += uint32(tcpLen & 0xffff)
sum += uint32(proto)
return sum
}