From 8fd724d762882c9154a6fef9c2204c7b9143a44b Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 24 Apr 2026 16:06:11 -0500 Subject: [PATCH] fix? --- overlay/tio/tun_linux_offload.go | 120 +++++++++++++++++++++---------- 1 file changed, 81 insertions(+), 39 deletions(-) diff --git a/overlay/tio/tun_linux_offload.go b/overlay/tio/tun_linux_offload.go index b060a2ba..6e30dec0 100644 --- a/overlay/tio/tun_linux_offload.go +++ b/overlay/tio/tun_linux_offload.go @@ -11,6 +11,42 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checksum" ) +// Protocol header size bounds used to validate / cap kernel-supplied offsets. +const ( + ipv4HeaderMinLen = 20 // IHL=5, no options + ipv4HeaderMaxLen = 60 // IHL=15, max options + ipv6FixedLen = 40 // IPv6 base header; extensions would extend this + tcpHeaderMinLen = 20 // data-offset=5, no options + tcpHeaderMaxLen = 60 // data-offset=15, max options +) + +// Byte offsets inside an IPv4 header. +const ( + ipv4TotalLenOff = 2 + ipv4IDOff = 4 + ipv4ChecksumOff = 10 + ipv4SrcOff = 12 + ipv4AddrsEnd = 20 // end of dst address (ipv4SrcOff + 2*4) +) + +// Byte offsets inside an IPv6 header. +const ( + ipv6PayloadLenOff = 4 + ipv6SrcOff = 8 + ipv6AddrsEnd = 40 // end of dst address (ipv6SrcOff + 2*16) +) + +// Byte offsets inside a TCP header (relative to its start, i.e. csumStart). +const ( + tcpSeqOff = 4 + tcpDataOffOff = 12 // upper nibble is header len in 32-bit words + tcpFlagsOff = 13 + tcpChecksumOff = 16 +) + +// tcpFinPshMask is cleared on every segment except the last of a TSO burst. +const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08) + // segmentInto splits a TUN-side packet described by hdr into one or more // IP packets, each appended to *out as a slice of scratch. scratch must be // sized to hold every segment (including replicated headers). @@ -77,26 +113,34 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err if hdr.GSOSize == 0 { return fmt.Errorf("gso_size is zero") } - if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 { - return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt)) - } - if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen { - return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen) + if hdr.CsumStart == 0 { + return fmt.Errorf("csum_start is zero") } isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4 - headerLen := int(hdr.HdrLen) csumStart := int(hdr.CsumStart) - if isV4 && csumStart < 20 { + if isV4 && csumStart < ipv4HeaderMinLen { return fmt.Errorf("csum_start %d too small for IPv4", csumStart) } - if !isV4 && csumStart < 40 { + if !isV4 && csumStart < ipv6FixedLen { return fmt.Errorf("csum_start %d too small for IPv6", csumStart) } - tcpHdrLen := headerLen - csumStart - if tcpHdrLen < 20 { - return fmt.Errorf("tcp header region too small: %d", tcpHdrLen) + + // Don't trust hdr.HdrLen from the kernel: on some paths it can be set + // to the full length of the first packet rather than the true L3+L4 header length. + // Instead, read the TCP data-offset field from the packet itself and derive + // headerLen = csum_start + tcpHdrLen. Matches wireguard-go's approach. + if csumStart+tcpFlagsOff+1 > len(pkt) { + return fmt.Errorf("packet too short for tcp header at csum_start=%d (pkt %d)", csumStart, len(pkt)) + } + tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4 + if tcpHdrLen < tcpHeaderMinLen || tcpHdrLen > tcpHeaderMaxLen { + return fmt.Errorf("tcp data-offset out of range: %d", tcpHdrLen) + } + headerLen := csumStart + tcpHdrLen + if headerLen > len(pkt) { + return fmt.Errorf("derived hdr_len %d > pkt %d", headerLen, len(pkt)) } payload := pkt[headerLen:] @@ -112,26 +156,24 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch)) } - origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8]) - origFlags := pkt[csumStart+13] - const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08) + origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4]) + origFlags := pkt[csumStart+tcpFlagsOff] - // Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP - // header is 60 bytes; copy onto the stack, zero the per-segment-varying - // fields, sum once. - var tmp [60]byte + // Precompute the TCP header sum with seq/flags/csum zeroed. Copy onto + // the stack, zero the per-segment-varying fields, sum once. + var tmp [tcpHeaderMaxLen]byte copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen]) - tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq - tmp[13] = 0 // flags - tmp[16], tmp[17] = 0, 0 // csum + tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0 + tmp[tcpFlagsOff] = 0 + tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0 baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0)) // Pseudo-header src+dst+proto contribution (tcpLen varies per segment). var baseProtoSum uint32 if isV4 { - baseProtoSum = uint32(checksum.Checksum(pkt[12:20], 0)) + baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0)) } else { - baseProtoSum = uint32(checksum.Checksum(pkt[8:40], 0)) + baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0)) } baseProtoSum += uint32(unix.IPPROTO_TCP) @@ -140,16 +182,16 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err var ihl int var baseIPHdrSum uint32 if isV4 { - origIPID = binary.BigEndian.Uint16(pkt[4:6]) + origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2]) ihl = int(pkt[0]&0x0f) * 4 - if ihl < 20 || ihl > csumStart { + if ihl < ipv4HeaderMinLen || ihl > csumStart { return fmt.Errorf("bad IPv4 IHL: %d", ihl) } - var ipTmp [60]byte + var ipTmp [ipv4HeaderMaxLen]byte copy(ipTmp[:ihl], pkt[:ihl]) - ipTmp[2], ipTmp[3] = 0, 0 // total_len - ipTmp[4], ipTmp[5] = 0, 0 // id - ipTmp[10], ipTmp[11] = 0, 0 // checksum + ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0 + ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0 + ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0 baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0)) } @@ -170,26 +212,26 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err segSeq := origSeq + uint32(segStart) segFlags := origFlags if i != numSeg-1 { - segFlags = origFlags &^ tcpFinPsh + segFlags = origFlags &^ tcpFinPshMask } totalLen := headerLen + segPayLen // Patch IP header and write the v4 header checksum from the precomputed base. if isV4 { segID := origIPID + uint16(i) - binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen)) - binary.BigEndian.PutUint16(seg[4:6], segID) + binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen)) + binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID) ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID) - binary.BigEndian.PutUint16(seg[10:12], foldComplement(ipSum)) + binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum)) } else { - // IPv6 payload length excludes the 40-byte fixed header but - // includes any extension headers between [40:csumStart]. - binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen)) + // IPv6 payload length excludes the fixed header but includes any + // extension headers between [ipv6FixedLen:csumStart]. + binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen)) } // Patch TCP header. - binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq) - seg[csumStart+13] = segFlags + binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq) + seg[csumStart+tcpFlagsOff] = segFlags // (csum is written below; its prior contents in `seg` don't affect the // computation since we never sum over the segment's own header.) @@ -202,7 +244,7 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen) wide = (wide & 0xffffffff) + (wide >> 32) wide = (wide & 0xffffffff) + (wide >> 32) - binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], foldComplement(uint32(wide))) + binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide))) *out = append(*out, seg) }