diff --git a/overlay/tio/tun_linux_offload.go b/overlay/tio/tun_linux_offload.go index cc01b3e4..b060a2ba 100644 --- a/overlay/tio/tun_linux_offload.go +++ b/overlay/tio/tun_linux_offload.go @@ -8,6 +8,7 @@ import ( "fmt" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/checksum" ) // segmentInto splits a TUN-side packet described by hdr into one or more @@ -57,11 +58,10 @@ func finishChecksum(seg []byte, hdr VirtioNetHdr) error { } // The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the // L4 region starting at cs, folding the prior partial in as the seed. - partial := uint32(binary.BigEndian.Uint16(seg[cs+co : cs+co+2])) + partial := binary.BigEndian.Uint16(seg[cs+co : cs+co+2]) seg[cs+co] = 0 seg[cs+co+1] = 0 - sum := checksumBytes(seg[cs:], partial) - binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], checksumFold(sum)) + binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], ^checksum.Checksum(seg[cs:], partial)) return nil } @@ -124,16 +124,14 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq tmp[13] = 0 // flags tmp[16], tmp[17] = 0, 0 // csum - baseTcpHdrSum := checksumBytes(tmp[:tcpHdrLen], 0) + baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0)) // Pseudo-header src+dst+proto contribution (tcpLen varies per segment). var baseProtoSum uint32 if isV4 { - baseProtoSum = checksumBytes(pkt[12:16], 0) - baseProtoSum = checksumBytes(pkt[16:20], baseProtoSum) + baseProtoSum = uint32(checksum.Checksum(pkt[12:20], 0)) } else { - baseProtoSum = checksumBytes(pkt[8:24], 0) - baseProtoSum = checksumBytes(pkt[24:40], baseProtoSum) + baseProtoSum = uint32(checksum.Checksum(pkt[8:40], 0)) } baseProtoSum += uint32(unix.IPPROTO_TCP) @@ -152,7 +150,7 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err ipTmp[2], ipTmp[3] = 0, 0 // total_len ipTmp[4], ipTmp[5] = 0, 0 // id ipTmp[10], ipTmp[11] = 0, 0 // checksum - baseIPHdrSum = checksumBytes(ipTmp[:ihl], 0) + baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0)) } off := 0 @@ -182,7 +180,7 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen)) binary.BigEndian.PutUint16(seg[4:6], segID) ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID) - binary.BigEndian.PutUint16(seg[10:12], checksumFold(ipSum)) + binary.BigEndian.PutUint16(seg[10:12], foldComplement(ipSum)) } else { // IPv6 payload length excludes the 40-byte fixed header but // includes any extension headers between [40:csumStart]. @@ -196,7 +194,7 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err // computation since we never sum over the segment's own header.) tcpLen := tcpHdrLen + segPayLen - paySum := checksumBytes(payload[segStart:segEnd], 0) + paySum := uint32(checksum.Checksum(payload[segStart:segEnd], 0)) // Combine pre-folded uint32s into a wider accumulator, then fold. Using // uint64 guards against overflow when segSeq's high bits set. @@ -204,7 +202,7 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen) wide = (wide & 0xffffffff) + (wide >> 32) wide = (wide & 0xffffffff) + (wide >> 32) - binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(uint32(wide))) + binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], foldComplement(uint32(wide))) *out = append(*out, seg) } @@ -212,64 +210,30 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err return nil } -// checksumBytes returns the Internet-checksum partial sum of b, seeded with -// initial. Result is a 32-bit accumulator; the caller folds to 16. -// -// Each 4-byte load is added directly into a 64-bit accumulator. Two parallel -// accumulators break the serial dependency through `sum` and let the CPU -// overlap independent adds. The final fold from 64 → 32 → 16 handles the -// carries that accumulated across the 32-bit lane boundary. -func checksumBytes(b []byte, initial uint32) uint32 { - s0 := uint64(initial) - var s1 uint64 - for len(b) >= 32 { - s0 += uint64(binary.BigEndian.Uint32(b[0:4])) - s1 += uint64(binary.BigEndian.Uint32(b[4:8])) - s0 += uint64(binary.BigEndian.Uint32(b[8:12])) - s1 += uint64(binary.BigEndian.Uint32(b[12:16])) - s0 += uint64(binary.BigEndian.Uint32(b[16:20])) - s1 += uint64(binary.BigEndian.Uint32(b[20:24])) - s0 += uint64(binary.BigEndian.Uint32(b[24:28])) - s1 += uint64(binary.BigEndian.Uint32(b[28:32])) - b = b[32:] - } - sum := s0 + s1 - for len(b) >= 4 { - sum += uint64(binary.BigEndian.Uint32(b[:4])) - b = b[4:] - } - if len(b) >= 2 { - sum += uint64(binary.BigEndian.Uint16(b[:2])) - b = b[2:] - } - if len(b) == 1 { - sum += uint64(b[0]) << 8 - } - sum = (sum & 0xffffffff) + (sum >> 32) - sum = (sum & 0xffffffff) + (sum >> 32) - return uint32(sum) -} - -func checksumFold(sum uint32) uint16 { - for sum>>16 != 0 { - sum = (sum & 0xffff) + (sum >> 16) - } +// foldComplement folds a 32-bit one's-complement partial sum to 16 bits and +// complements it, yielding the on-wire Internet checksum value. +func foldComplement(sum uint32) uint16 { + sum = (sum & 0xffff) + (sum >> 16) + sum = (sum & 0xffff) + (sum >> 16) return ^uint16(sum) } -func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint32 { - sum := checksumBytes(src, 0) - sum = checksumBytes(dst, sum) - sum += uint32(proto) - sum += uint32(tcpLen) - return sum +// pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a TCP +// segment's checksum in tests. src/dst are 4 bytes each. +func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint16 { + s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0)) + s += uint32(proto) + uint32(tcpLen) + s = (s & 0xffff) + (s >> 16) + s = (s & 0xffff) + (s >> 16) + return uint16(s) } -func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint32 { - sum := checksumBytes(src, 0) - sum = checksumBytes(dst, sum) - sum += uint32(tcpLen >> 16) - sum += uint32(tcpLen & 0xffff) - sum += uint32(proto) - return sum +// pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a TCP +// segment's checksum in tests. src/dst are 16 bytes each. +func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint16 { + s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0)) + s += uint32(tcpLen>>16) + uint32(tcpLen&0xffff) + uint32(proto) + s = (s & 0xffff) + (s >> 16) + s = (s & 0xffff) + (s >> 16) + return uint16(s) } diff --git a/overlay/tio/tun_linux_offload_test.go b/overlay/tio/tun_linux_offload_test.go index 252d823d..ff080b7c 100644 --- a/overlay/tio/tun_linux_offload_test.go +++ b/overlay/tio/tun_linux_offload_test.go @@ -9,16 +9,13 @@ import ( "testing" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/checksum" ) -// verifyChecksum confirms that the one's-complement sum across `b`, optionally -// seeded with a pseudo-header sum, folds to all-ones (valid). -func verifyChecksum(b []byte, pseudo uint32) bool { - sum := checksumBytes(b, pseudo) - for sum>>16 != 0 { - sum = (sum & 0xffff) + (sum >> 16) - } - return uint16(sum) == 0xffff +// verifyChecksum confirms that the one's-complement sum across `b`, seeded +// with a folded pseudo-header sum, equals all-ones (valid). +func verifyChecksum(b []byte, pseudo uint16) bool { + return checksum.Checksum(b, pseudo) == 0xffff } // buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of