mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
402 lines
14 KiB
Go
402 lines
14 KiB
Go
//go:build linux && !android && !e2e_testing
|
|
// +build linux,!android,!e2e_testing
|
|
|
|
// Package virtio implements the pure validation, header-correction, and
|
|
// per-segment slicing logic for kernel-supplied TSO/USO superpackets on
|
|
// IFF_VNET_HDR TUN devices. It is FD-free and depends only on the byte
|
|
// layout of the virtio_net_hdr and the IP/TCP/UDP headers it describes,
|
|
// so it can be unit-tested in isolation from the tio Queue runtime.
|
|
package virtio
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"golang.org/x/sys/unix"
|
|
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
|
)
|
|
|
|
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
const (
	ipv4HeaderMinLen = 20 // IHL=5, no options
	ipv4HeaderMaxLen = 60 // IHL=15, max options
	ipv6FixedLen     = 40 // IPv6 base header; extensions would extend this
	tcpHeaderMinLen  = 20 // data-offset=5, no options
	tcpHeaderMaxLen  = 60 // data-offset=15, max options
)

// Byte offsets inside an IPv4 header.
const (
	ipv4TotalLenOff = 2  // 16-bit total_length field
	ipv4IDOff       = 4  // 16-bit identification field
	ipv4ChecksumOff = 10 // 16-bit header checksum field
	ipv4SrcOff      = 12 // start of the 4-byte source address
	ipv4AddrsEnd    = 20 // end of dst address (ipv4SrcOff + 2*4)
)

// Byte offsets inside an IPv6 header.
const (
	ipv6PayloadLenOff = 4  // 16-bit payload_length field
	ipv6SrcOff        = 8  // start of the 16-byte source address
	ipv6AddrsEnd      = 40 // end of dst address (ipv6SrcOff + 2*16)
)

// Byte offsets inside a TCP header (relative to its start, i.e. csumStart).
const (
	tcpSeqOff      = 4  // 32-bit sequence number
	tcpDataOffOff  = 12 // upper nibble is header len in 32-bit words
	tcpFlagsOff    = 13 // low byte of the offset/flags word: CWR..FIN bits
	tcpChecksumOff = 16 // 16-bit checksum field
)

// UDP header is fixed at 8 bytes: {sport, dport, length, checksum}.
const (
	udpHeaderLen   = 8
	udpLengthOff   = 4 // 16-bit datagram length (header + payload)
	udpChecksumOff = 6 // 16-bit checksum field
)

// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)

// tcpCwrFlag is cleared on every segment except the first. Per RFC 3168
// §6.1.2 the CWR bit signals a one-shot transition (the sender just halved
// its window) and must appear on the first segment of a TSO burst only.
const tcpCwrFlag = 0x80
|
|
|
|
// CheckValid rejects packets whose virtio_net_hdr/IP combination would
|
|
// cause a downstream miscompute. The TUN should never emit RSC_INFO and
|
|
// the GSO type must agree with the IP version nibble.
|
|
func CheckValid(pkt []byte, hdr Hdr) error {
|
|
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
|
|
// carry coalescing info rather than checksum offsets. A TUN writing via
|
|
// IFF_VNET_HDR should never emit this, but if it did we would silently
|
|
// miscompute the segment checksums — refuse the packet instead.
|
|
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
|
|
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
|
|
}
|
|
if len(pkt) < ipv4HeaderMinLen {
|
|
return fmt.Errorf("packet too short")
|
|
}
|
|
ipVersion := pkt[0] >> 4
|
|
switch hdr.GSOType {
|
|
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
|
|
if ipVersion != 4 {
|
|
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
|
}
|
|
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
|
if ipVersion != 6 {
|
|
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
|
}
|
|
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
|
|
// USO carries either v4 or v6; the leading nibble disambiguates.
|
|
if !(ipVersion == 4 || ipVersion == 6) {
|
|
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
|
}
|
|
default:
|
|
if !(ipVersion == 6 || ipVersion == 4) {
|
|
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CorrectHdrLen rewrites hdr.HdrLen based on the actual transport header
|
|
// length read out of pkt. The kernel's hdr.HdrLen on the FORWARD path can
|
|
// be the length of the entire first packet, so we don't trust it.
|
|
func CorrectHdrLen(pkt []byte, hdr *Hdr) error {
|
|
// Thank you wireguard-go for documenting these edge-cases
|
|
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
|
// of the entire first packet when the kernel is handling it as part of a
|
|
// FORWARD path. Instead, parse the transport header length and add it onto
|
|
// csumStart, which is synonymous for IP header length.
|
|
|
|
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
|
hdr.HdrLen = hdr.CsumStart + 8
|
|
} else {
|
|
if len(pkt) <= int(hdr.CsumStart+tcpDataOffOff) {
|
|
return errors.New("packet is too short")
|
|
}
|
|
|
|
tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffOff] >> 4 * 4)
|
|
if tcpHLen < 20 || tcpHLen > 60 {
|
|
// A TCP header must be between 20 and 60 bytes in length.
|
|
return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
|
}
|
|
hdr.HdrLen = hdr.CsumStart + tcpHLen
|
|
}
|
|
|
|
if len(pkt) < int(hdr.HdrLen) {
|
|
return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
|
|
}
|
|
|
|
if hdr.HdrLen < hdr.CsumStart {
|
|
return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
|
|
}
|
|
cSumAt := int(hdr.CsumStart + hdr.CsumStart)
|
|
if cSumAt+1 >= len(pkt) {
|
|
return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SegmentTCP walks a TSO superpacket pkt, yielding each segment as a
// slice into pkt itself. Per-segment plaintext is laid out by sliding a
// freshly-patched copy of the L3+L4 header into pkt at offset i*gsoSize,
// where it sits immediately before that segment's payload chunk in the
// original buffer. The slide is destructive: iter i's header write overwrites
// the last hdrLen bytes of seg_{i-1}'s payload, which is dead by the time
// the next iteration begins. pkt is consumed by this call and must not be
// inspected by the caller after the final yield.
//
// hdrLenU is the combined L3+L4 header length, csumStartU the L3 header
// length (start of TCP), gsoSizeU the per-segment payload size. Callers are
// expected to have validated these via CheckValid/CorrectHdrLen first; only
// gso_size/csum_start zero and a bad IPv4 IHL are re-checked here.
func SegmentTCP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
	if gsoSizeU == 0 {
		return fmt.Errorf("gso_size is zero")
	}
	if csumStartU == 0 {
		return fmt.Errorf("csum_start is zero")
	}

	headerLen := int(hdrLenU)
	csumStart := int(csumStartU)
	isV4 := pkt[0]>>4 == 4

	// TCP header length from the data-offset nibble (32-bit words).
	tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
	payLen := len(pkt) - headerLen
	gsoSize := int(gsoSizeU)
	// Ceiling division; clamp to 1 so a zero-payload packet (e.g. bare FIN)
	// still yields one segment.
	numSeg := (payLen + gsoSize - 1) / gsoSize
	if numSeg == 0 {
		numSeg = 1
	}

	// Sequence number and flags of the superpacket; each segment derives its
	// own seq (origSeq + payload offset) and flags (CWR/FIN/PSH masking).
	origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4])
	origFlags := pkt[csumStart+tcpFlagsOff]

	// Precompute the one's-complement sum of the TCP header with all
	// per-segment fields (seq, flags, checksum) zeroed; the loop below adds
	// the per-segment values back in instead of re-summing the header.
	var tmp [tcpHeaderMaxLen]byte
	copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
	tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0
	tmp[tcpFlagsOff] = 0
	tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0
	baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0))

	// Pseudo-header partial: src+dst addresses plus the protocol number.
	// The per-segment TCP length is added inside the loop.
	var baseProtoSum uint32
	if isV4 {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
	} else {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
	}
	baseProtoSum += uint32(unix.IPPROTO_TCP)

	// For IPv4, precompute the IP header sum with total_len, id, and
	// checksum zeroed; TSO bumps the IP ID by 1 per segment.
	var origIPID uint16
	var baseIPHdrSum uint32
	if isV4 {
		origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2])
		ihl := int(pkt[0]&0x0f) * 4
		if ihl < ipv4HeaderMinLen || ihl > csumStart {
			return fmt.Errorf("bad IPv4 IHL: %d", ihl)
		}
		var ipTmp [ipv4HeaderMaxLen]byte
		copy(ipTmp[:ihl], pkt[:ihl])
		ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
		ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0
		ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
		baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
	}

	for i := 0; i < numSeg; i++ {
		segStart := i * gsoSize
		segEnd := segStart + gsoSize
		if segEnd > payLen {
			segEnd = payLen // final segment may be short
		}
		segPayLen := segEnd - segStart
		segLen := headerLen + segPayLen
		headerOff := i * gsoSize

		// Slide the header into place immediately before this segment's
		// payload. Iter 0's header is already at pkt[:headerLen]; for
		// i ≥ 1 we copy from there. The constant-byte fields of pkt[:headerLen]
		// survive iter 0's in-place patches (only seq/flags/cksum/totalLen/id
		// are touched), and iter 0's stale variable-field values are
		// overwritten by the per-segment patches below.
		if i > 0 {
			copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
		}
		seg := pkt[headerOff : headerOff+segLen]

		// Per-segment sequence number advances by the payload offset.
		segSeq := origSeq + uint32(segStart)
		segFlags := origFlags
		if i != 0 {
			segFlags &^= tcpCwrFlag // CWR only on the first segment
		}
		if i != numSeg-1 {
			segFlags &^= tcpFinPshMask // FIN/PSH only on the last segment
		}
		totalLen := segLen

		if isV4 {
			segID := origIPID + uint16(i)
			binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
			binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID)
			// Incremental IP checksum: base sum + the two patched words.
			ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
			binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
		} else {
			binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
		}

		binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq)
		seg[csumStart+tcpFlagsOff] = segFlags

		tcpLen := tcpHdrLen + segPayLen
		// Payload bytes still live at their original offset in pkt. The
		// header slide above only writes into pkt[i*G : i*G+H], which is
		// the tail of seg_{i-1}'s payload (already consumed) and never
		// overlaps seg_i's own payload at pkt[H+i*G : H+(i+1)*G].
		paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
		// Accumulate in 64 bits, then fold twice to 32 so foldComplement's
		// 16-bit folds are exact. segSeq contributes its two 16-bit words;
		// segFlags is the low byte of the offset/flags word (the high byte
		// is already in baseTcpHdrSum); tcpLen is the pseudo-header length.
		wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide)))

		if err := yield(seg); err != nil {
			return err
		}
	}

	return nil
}
|
|
|
|
// SegmentUDP walks a USO superpacket, sliding a per-segment-patched
// L3+L4 header into pkt at offset i*gsoSize and yielding pkt[i*G:i*G+segLen]
// to the caller. Per-segment patches are total_len + IPv4 csum (or IPv6
// payload_len) plus the UDP length and checksum. pkt is consumed
// destructively; see SegmentTCP for the layout reasoning.
//
// UDP-GSO leaves the IPv4 ID identical across segments (the kernel does not
// bump it), which is why the IP-level per-segment work is limited to
// total_len + IPv4 header checksum (v4) or payload_len (v6).
func SegmentUDP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
	if gsoSizeU == 0 {
		return fmt.Errorf("gso_size is zero")
	}
	if csumStartU == 0 {
		return fmt.Errorf("csum_start is zero")
	}

	isV4 := pkt[0]>>4 == 4
	headerLen := int(hdrLenU)
	csumStart := int(csumStartU)
	// The gap between csum_start and hdr_len must be exactly the fixed
	// 8-byte UDP header.
	if headerLen-csumStart != udpHeaderLen {
		return fmt.Errorf("udp header len mismatch: %d", headerLen-csumStart)
	}

	payLen := len(pkt) - headerLen
	gsoSize := int(gsoSizeU)
	// Ceiling division; clamp to 1 so an empty-payload datagram still
	// yields one segment.
	numSeg := (payLen + gsoSize - 1) / gsoSize
	if numSeg == 0 {
		numSeg = 1
	}

	// Precompute the one's-complement sum of the UDP header with the
	// per-segment fields (length, checksum) zeroed; ports are invariant.
	var udpTmp [udpHeaderLen]byte
	copy(udpTmp[:], pkt[csumStart:headerLen])
	udpTmp[udpLengthOff], udpTmp[udpLengthOff+1] = 0, 0
	udpTmp[udpChecksumOff], udpTmp[udpChecksumOff+1] = 0, 0
	baseUDPHdrSum := uint32(checksum.Checksum(udpTmp[:], 0))

	// Pseudo-header partial: src+dst addresses plus the protocol number.
	// The per-segment UDP length is added inside the loop.
	var baseProtoSum uint32
	if isV4 {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
	} else {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
	}
	baseProtoSum += uint32(unix.IPPROTO_UDP)

	// For IPv4, precompute the IP header sum with total_len and checksum
	// zeroed (the ID is NOT zeroed: USO keeps it constant across segments).
	var baseIPHdrSum uint32
	if isV4 {
		ihl := int(pkt[0]&0x0f) * 4
		if ihl < ipv4HeaderMinLen || ihl > csumStart {
			return fmt.Errorf("bad IPv4 IHL: %d", ihl)
		}
		var ipTmp [ipv4HeaderMaxLen]byte
		copy(ipTmp[:ihl], pkt[:ihl])
		ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
		ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
		baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
	}

	for i := 0; i < numSeg; i++ {
		segStart := i * gsoSize
		segEnd := segStart + gsoSize
		if segEnd > payLen {
			segEnd = payLen // final segment may be short
		}
		segPayLen := segEnd - segStart
		segLen := headerLen + segPayLen
		headerOff := i * gsoSize

		// Slide the header to just before this segment's payload; see
		// SegmentTCP for why this never clobbers live bytes.
		if i > 0 {
			copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
		}
		seg := pkt[headerOff : headerOff+segLen]

		totalLen := segLen
		udpLen := udpHeaderLen + segPayLen

		if isV4 {
			binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
			// Incremental IP checksum: base sum + the patched total_len.
			ipSum := baseIPHdrSum + uint32(totalLen)
			binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
		} else {
			binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
		}

		binary.BigEndian.PutUint16(seg[csumStart+udpLengthOff:csumStart+udpLengthOff+2], uint16(udpLen))

		// Payload still lives at its original offset in pkt (see SegmentTCP).
		paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
		// udpLen is added twice: once for the length field in the UDP header
		// (zeroed in baseUDPHdrSum) and once for the pseudo-header length.
		wide := uint64(baseUDPHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		wide += uint64(udpLen) + uint64(udpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		csum := foldComplement(uint32(wide))
		// RFC 768: an all-zero transmitted checksum means "no checksum";
		// a computed zero is sent as 0xffff instead.
		if csum == 0 {
			csum = 0xffff
		}
		binary.BigEndian.PutUint16(seg[csumStart+udpChecksumOff:csumStart+udpChecksumOff+2], csum)

		if err := yield(seg); err != nil {
			return err
		}
	}

	return nil
}
|
|
|
|
// FinishChecksum computes the L4 checksum for a non-GSO packet that the kernel
|
|
// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit
|
|
// checksum field; we zero it, fold a full sum (the field was pre-loaded with
|
|
// the pseudo-header partial sum by the kernel), and store the result.
|
|
func FinishChecksum(seg []byte, hdr Hdr) error {
|
|
cs := int(hdr.CsumStart)
|
|
co := int(hdr.CsumOffset)
|
|
if cs+co+2 > len(seg) {
|
|
return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg))
|
|
}
|
|
// The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the
|
|
// L4 region starting at cs, folding the prior partial in as the seed.
|
|
partial := binary.BigEndian.Uint16(seg[cs+co : cs+co+2])
|
|
seg[cs+co] = 0
|
|
seg[cs+co+1] = 0
|
|
binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], ^checksum.Checksum(seg[cs:], partial))
|
|
return nil
|
|
}
|
|
|
|
// foldComplement folds a 32-bit one's-complement partial sum down to 16 bits
// and complements it, yielding the on-wire Internet checksum value.
func foldComplement(sum uint32) uint16 {
	// Each fold can reintroduce a carry, so iterate until the sum fits in
	// 16 bits (at most two passes for a 32-bit input).
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return ^uint16(sum)
}
|