//go:build linux && !android && !e2e_testing // +build linux,!android,!e2e_testing package tio import ( "encoding/binary" "errors" "fmt" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip/checksum" ) // Protocol header size bounds used to validate / cap kernel-supplied offsets. const ( ipv4HeaderMinLen = 20 // IHL=5, no options ipv4HeaderMaxLen = 60 // IHL=15, max options ipv6FixedLen = 40 // IPv6 base header; extensions would extend this tcpHeaderMinLen = 20 // data-offset=5, no options tcpHeaderMaxLen = 60 // data-offset=15, max options ) // Byte offsets inside an IPv4 header. const ( ipv4TotalLenOff = 2 ipv4IDOff = 4 ipv4ChecksumOff = 10 ipv4SrcOff = 12 ipv4AddrsEnd = 20 // end of dst address (ipv4SrcOff + 2*4) ) // Byte offsets inside an IPv6 header. const ( ipv6PayloadLenOff = 4 ipv6SrcOff = 8 ipv6AddrsEnd = 40 // end of dst address (ipv6SrcOff + 2*16) ) // Byte offsets inside a TCP header (relative to its start, i.e. csumStart). const ( tcpSeqOff = 4 tcpDataOffOff = 12 // upper nibble is header len in 32-bit words tcpFlagsOff = 13 tcpChecksumOff = 16 ) // tcpFinPshMask is cleared on every segment except the last of a TSO burst. const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08) func checkVirtioValid(pkt []byte, hdr VirtioNetHdr) error { // When RSC_INFO is set the csum_start/csum_offset fields are repurposed to // carry coalescing info rather than checksum offsets. A TUN writing via // IFF_VNET_HDR should never emit this, but if it did we would silently // miscompute the segment checksums — refuse the packet instead. if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 { return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads") } if len(pkt) < ipv4HeaderMinLen { return fmt.Errorf("packet too short") } ipVersion := pkt[0] >> 4 switch hdr.GSOType { case unix.VIRTIO_NET_HDR_GSO_TCPV4: if ipVersion != 4 { return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType) } case unix.VIRTIO_NET_HDR_GSO_TCPV6: if ipVersion != 6 { return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType) } default: if !(ipVersion == 6 || ipVersion == 4) { return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType) } } return nil } func handleGSONone(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error { if len(pkt) > len(scratch) { return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch)) } copy(scratch, pkt) seg := scratch[:len(pkt)] if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { if err := finishChecksum(seg, hdr); err != nil { return err } } *out = append(*out, seg) return nil } func correctHdrLen(pkt []byte, hdr *VirtioNetHdr) error { // Thank you wireguard-go for documenting these edge-cases // Don't trust hdr.hdrLen from the kernel as it can be equal to the length // of the entire first packet when the kernel is handling it as part of a // FORWARD path. Instead, parse the transport header length and add it onto // csumStart, which is synonymous for IP header length. const tcpDataOffset = 12 if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { hdr.HdrLen = hdr.CsumStart + 8 } else { if len(pkt) <= int(hdr.CsumStart+tcpDataOffset) { return errors.New("packet is too short") } tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffset] >> 4 * 4) if tcpHLen < 20 || tcpHLen > 60 { // A TCP header must be between 20 and 60 bytes in length. return fmt.Errorf("tcp header len is invalid: %d", tcpHLen) } hdr.HdrLen = hdr.CsumStart + tcpHLen } if len(pkt) < int(hdr.HdrLen) { return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen) } if hdr.HdrLen < hdr.CsumStart { return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart) } cSumAt := int(hdr.CsumStart + hdr.CsumStart) if cSumAt+1 >= len(pkt) { return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt)) } return nil } // segmentInto splits a TUN-side packet described by hdr into one or more // IP packets, each appended to *out as a slice of scratch. scratch must be // sized to hold every segment (including replicated headers). func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error { if err := checkVirtioValid(pkt, hdr); err != nil { return err } if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE { return handleGSONone(pkt, hdr, out, scratch) } if err := correctHdrLen(pkt, &hdr); err != nil { return err } switch hdr.GSOType { case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6: return segmentTCP(pkt, hdr, out, scratch) default: return fmt.Errorf("unsupported virtio gso type: %d", hdr.GSOType) } } // finishChecksum computes the L4 checksum for a non-GSO packet that the kernel // handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit // checksum field; we zero it, fold a full sum (the field was pre-loaded with // the pseudo-header partial sum by the kernel), and store the result. func finishChecksum(seg []byte, hdr VirtioNetHdr) error { cs := int(hdr.CsumStart) co := int(hdr.CsumOffset) if cs+co+2 > len(seg) { return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg)) } // The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the // L4 region starting at cs, folding the prior partial in as the seed. partial := binary.BigEndian.Uint16(seg[cs+co : cs+co+2]) seg[cs+co] = 0 seg[cs+co+1] = 0 binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], ^checksum.Checksum(seg[cs:], partial)) return nil } // segmentTCP software-segments a TSO superpacket into one IP packet per MSS // chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6. // // Hot-path shape: the per-segment loop only sums the payload chunk. The TCP // header, the IPv4 header, and the pseudo-header src/dst/proto contributions // are each summed once up front — every segment reuses those three pre-folded // uint32 values and combines them with small per-segment deltas (seq, flags, // tcpLen, ip_id, total_len) that are cheap to fold in. func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error { if hdr.GSOSize == 0 { return fmt.Errorf("gso_size is zero") } if hdr.CsumStart == 0 { return fmt.Errorf("csum_start is zero") } isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4 headerLen := int(hdr.HdrLen) // already corrected by the caller csumStart := int(hdr.CsumStart) tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4 payload := pkt[headerLen:] payLen := len(payload) gsoSize := int(hdr.GSOSize) numSeg := (payLen + gsoSize - 1) / gsoSize if numSeg == 0 { numSeg = 1 } need := numSeg*headerLen + payLen if need > len(scratch) { return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch)) } origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4]) origFlags := pkt[csumStart+tcpFlagsOff] // Precompute the TCP header sum with seq/flags/csum zeroed. Copy onto // the stack, zero the per-segment-varying fields, sum once. var tmp [tcpHeaderMaxLen]byte copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen]) tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0 tmp[tcpFlagsOff] = 0 tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0 baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0)) // Pseudo-header src+dst+proto contribution (tcpLen varies per segment). var baseProtoSum uint32 if isV4 { baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0)) } else { baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0)) } baseProtoSum += uint32(unix.IPPROTO_TCP) // Precompute IPv4 header sum with total_len/id/csum zeroed. var origIPID uint16 var ihl int var baseIPHdrSum uint32 if isV4 { origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2]) ihl = int(pkt[0]&0x0f) * 4 if ihl < ipv4HeaderMinLen || ihl > csumStart { return fmt.Errorf("bad IPv4 IHL: %d", ihl) } var ipTmp [ipv4HeaderMaxLen]byte copy(ipTmp[:ihl], pkt[:ihl]) ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0 ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0 ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0 baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0)) } off := 0 for i := 0; i < numSeg; i++ { segStart := i * gsoSize segEnd := segStart + gsoSize if segEnd > payLen { segEnd = payLen } segPayLen := segEnd - segStart copy(scratch[off:], pkt[:headerLen]) copy(scratch[off+headerLen:], payload[segStart:segEnd]) seg := scratch[off : off+headerLen+segPayLen] off += headerLen + segPayLen segSeq := origSeq + uint32(segStart) segFlags := origFlags if i != numSeg-1 { segFlags = origFlags &^ tcpFinPshMask } totalLen := headerLen + segPayLen // Patch IP header and write the v4 header checksum from the precomputed base. if isV4 { segID := origIPID + uint16(i) binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen)) binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID) ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID) binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum)) } else { // IPv6 payload length excludes the fixed header but includes any // extension headers between [ipv6FixedLen:csumStart]. binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen)) } // Patch TCP header. binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq) seg[csumStart+tcpFlagsOff] = segFlags // (csum is written below; its prior contents in `seg` don't affect the // computation since we never sum over the segment's own header.) tcpLen := tcpHdrLen + segPayLen paySum := uint32(checksum.Checksum(payload[segStart:segEnd], 0)) // Combine pre-folded uint32s into a wider accumulator, then fold. Using // uint64 guards against overflow when segSeq's high bits set. wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum) wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen) wide = (wide & 0xffffffff) + (wide >> 32) wide = (wide & 0xffffffff) + (wide >> 32) binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide))) *out = append(*out, seg) } return nil } // foldComplement folds a 32-bit one's-complement partial sum to 16 bits and // complements it, yielding the on-wire Internet checksum value. func foldComplement(sum uint32) uint16 { sum = (sum & 0xffff) + (sum >> 16) sum = (sum & 0xffff) + (sum >> 16) return ^uint16(sum) } // pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a TCP // segment's checksum in tests. src/dst are 4 bytes each. func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint16 { s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0)) s += uint32(proto) + uint32(tcpLen) s = (s & 0xffff) + (s >> 16) s = (s & 0xffff) + (s >> 16) return uint16(s) } // pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a TCP // segment's checksum in tests. src/dst are 16 bytes each. func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint16 { s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0)) s += uint32(tcpLen>>16) + uint32(tcpLen&0xffff) + uint32(proto) s = (s & 0xffff) + (s >> 16) s = (s & 0xffff) + (s >> 16) return uint16(s) }