mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 12:57:38 +02:00
fix?
This commit is contained in:
@@ -11,6 +11,42 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
|
||||||
|
const (
|
||||||
|
ipv4HeaderMinLen = 20 // IHL=5, no options
|
||||||
|
ipv4HeaderMaxLen = 60 // IHL=15, max options
|
||||||
|
ipv6FixedLen = 40 // IPv6 base header; extensions would extend this
|
||||||
|
tcpHeaderMinLen = 20 // data-offset=5, no options
|
||||||
|
tcpHeaderMaxLen = 60 // data-offset=15, max options
|
||||||
|
)
|
||||||
|
|
||||||
|
// Byte offsets inside an IPv4 header.
|
||||||
|
const (
|
||||||
|
ipv4TotalLenOff = 2
|
||||||
|
ipv4IDOff = 4
|
||||||
|
ipv4ChecksumOff = 10
|
||||||
|
ipv4SrcOff = 12
|
||||||
|
ipv4AddrsEnd = 20 // end of dst address (ipv4SrcOff + 2*4)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Byte offsets inside an IPv6 header.
|
||||||
|
const (
|
||||||
|
ipv6PayloadLenOff = 4
|
||||||
|
ipv6SrcOff = 8
|
||||||
|
ipv6AddrsEnd = 40 // end of dst address (ipv6SrcOff + 2*16)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Byte offsets inside a TCP header (relative to its start, i.e. csumStart).
|
||||||
|
const (
|
||||||
|
tcpSeqOff = 4
|
||||||
|
tcpDataOffOff = 12 // upper nibble is header len in 32-bit words
|
||||||
|
tcpFlagsOff = 13
|
||||||
|
tcpChecksumOff = 16
|
||||||
|
)
|
||||||
|
|
||||||
|
// tcpFinPshMask selects the FIN and PSH flag bits, which are cleared on every segment except the last of a TSO burst.
|
||||||
|
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)
|
||||||
|
|
||||||
// segmentInto splits a TUN-side packet described by hdr into one or more
|
// segmentInto splits a TUN-side packet described by hdr into one or more
|
||||||
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
||||||
// sized to hold every segment (including replicated headers).
|
// sized to hold every segment (including replicated headers).
|
||||||
@@ -77,26 +113,34 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
if hdr.GSOSize == 0 {
|
if hdr.GSOSize == 0 {
|
||||||
return fmt.Errorf("gso_size is zero")
|
return fmt.Errorf("gso_size is zero")
|
||||||
}
|
}
|
||||||
if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 {
|
if hdr.CsumStart == 0 {
|
||||||
return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt))
|
return fmt.Errorf("csum_start is zero")
|
||||||
}
|
|
||||||
if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen {
|
|
||||||
return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
|
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||||
headerLen := int(hdr.HdrLen)
|
|
||||||
csumStart := int(hdr.CsumStart)
|
csumStart := int(hdr.CsumStart)
|
||||||
|
|
||||||
if isV4 && csumStart < 20 {
|
if isV4 && csumStart < ipv4HeaderMinLen {
|
||||||
return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
|
return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
|
||||||
}
|
}
|
||||||
if !isV4 && csumStart < 40 {
|
if !isV4 && csumStart < ipv6FixedLen {
|
||||||
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
|
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
|
||||||
}
|
}
|
||||||
tcpHdrLen := headerLen - csumStart
|
|
||||||
if tcpHdrLen < 20 {
|
// Don't trust hdr.HdrLen from the kernel: on some paths it can be set
|
||||||
return fmt.Errorf("tcp header region too small: %d", tcpHdrLen)
|
// to the full length of the first packet rather than the true L3+L4 header length.
|
||||||
|
// Instead, read the TCP data-offset field from the packet itself and derive
|
||||||
|
// headerLen = csum_start + tcpHdrLen. Matches wireguard-go's approach.
|
||||||
|
if csumStart+tcpFlagsOff+1 > len(pkt) {
|
||||||
|
return fmt.Errorf("packet too short for tcp header at csum_start=%d (pkt %d)", csumStart, len(pkt))
|
||||||
|
}
|
||||||
|
tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
|
||||||
|
if tcpHdrLen < tcpHeaderMinLen || tcpHdrLen > tcpHeaderMaxLen {
|
||||||
|
return fmt.Errorf("tcp data-offset out of range: %d", tcpHdrLen)
|
||||||
|
}
|
||||||
|
headerLen := csumStart + tcpHdrLen
|
||||||
|
if headerLen > len(pkt) {
|
||||||
|
return fmt.Errorf("derived hdr_len %d > pkt %d", headerLen, len(pkt))
|
||||||
}
|
}
|
||||||
|
|
||||||
payload := pkt[headerLen:]
|
payload := pkt[headerLen:]
|
||||||
@@ -112,26 +156,24 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch))
|
return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch))
|
||||||
}
|
}
|
||||||
|
|
||||||
origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8])
|
origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4])
|
||||||
origFlags := pkt[csumStart+13]
|
origFlags := pkt[csumStart+tcpFlagsOff]
|
||||||
const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08)
|
|
||||||
|
|
||||||
// Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP
|
// Precompute the TCP header sum with seq/flags/csum zeroed. Copy onto
|
||||||
// header is 60 bytes; copy onto the stack, zero the per-segment-varying
|
// the stack, zero the per-segment-varying fields, sum once.
|
||||||
// fields, sum once.
|
var tmp [tcpHeaderMaxLen]byte
|
||||||
var tmp [60]byte
|
|
||||||
copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
|
copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
|
||||||
tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq
|
tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0
|
||||||
tmp[13] = 0 // flags
|
tmp[tcpFlagsOff] = 0
|
||||||
tmp[16], tmp[17] = 0, 0 // csum
|
tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0
|
||||||
baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0))
|
baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0))
|
||||||
|
|
||||||
// Pseudo-header src+dst+proto contribution (tcpLen varies per segment).
|
// Pseudo-header src+dst+proto contribution (tcpLen varies per segment).
|
||||||
var baseProtoSum uint32
|
var baseProtoSum uint32
|
||||||
if isV4 {
|
if isV4 {
|
||||||
baseProtoSum = uint32(checksum.Checksum(pkt[12:20], 0))
|
baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
|
||||||
} else {
|
} else {
|
||||||
baseProtoSum = uint32(checksum.Checksum(pkt[8:40], 0))
|
baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
|
||||||
}
|
}
|
||||||
baseProtoSum += uint32(unix.IPPROTO_TCP)
|
baseProtoSum += uint32(unix.IPPROTO_TCP)
|
||||||
|
|
||||||
@@ -140,16 +182,16 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
var ihl int
|
var ihl int
|
||||||
var baseIPHdrSum uint32
|
var baseIPHdrSum uint32
|
||||||
if isV4 {
|
if isV4 {
|
||||||
origIPID = binary.BigEndian.Uint16(pkt[4:6])
|
origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2])
|
||||||
ihl = int(pkt[0]&0x0f) * 4
|
ihl = int(pkt[0]&0x0f) * 4
|
||||||
if ihl < 20 || ihl > csumStart {
|
if ihl < ipv4HeaderMinLen || ihl > csumStart {
|
||||||
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
|
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
|
||||||
}
|
}
|
||||||
var ipTmp [60]byte
|
var ipTmp [ipv4HeaderMaxLen]byte
|
||||||
copy(ipTmp[:ihl], pkt[:ihl])
|
copy(ipTmp[:ihl], pkt[:ihl])
|
||||||
ipTmp[2], ipTmp[3] = 0, 0 // total_len
|
ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
|
||||||
ipTmp[4], ipTmp[5] = 0, 0 // id
|
ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0
|
||||||
ipTmp[10], ipTmp[11] = 0, 0 // checksum
|
ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
|
||||||
baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
|
baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,26 +212,26 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
segSeq := origSeq + uint32(segStart)
|
segSeq := origSeq + uint32(segStart)
|
||||||
segFlags := origFlags
|
segFlags := origFlags
|
||||||
if i != numSeg-1 {
|
if i != numSeg-1 {
|
||||||
segFlags = origFlags &^ tcpFinPsh
|
segFlags = origFlags &^ tcpFinPshMask
|
||||||
}
|
}
|
||||||
totalLen := headerLen + segPayLen
|
totalLen := headerLen + segPayLen
|
||||||
|
|
||||||
// Patch IP header and write the v4 header checksum from the precomputed base.
|
// Patch IP header and write the v4 header checksum from the precomputed base.
|
||||||
if isV4 {
|
if isV4 {
|
||||||
segID := origIPID + uint16(i)
|
segID := origIPID + uint16(i)
|
||||||
binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen))
|
binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
|
||||||
binary.BigEndian.PutUint16(seg[4:6], segID)
|
binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID)
|
||||||
ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
|
ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
|
||||||
binary.BigEndian.PutUint16(seg[10:12], foldComplement(ipSum))
|
binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
|
||||||
} else {
|
} else {
|
||||||
// IPv6 payload length excludes the 40-byte fixed header but
|
// IPv6 payload length excludes the fixed header but includes any
|
||||||
// includes any extension headers between [40:csumStart].
|
// extension headers between [ipv6FixedLen:csumStart].
|
||||||
binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen))
|
binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Patch TCP header.
|
// Patch TCP header.
|
||||||
binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq)
|
binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq)
|
||||||
seg[csumStart+13] = segFlags
|
seg[csumStart+tcpFlagsOff] = segFlags
|
||||||
// (csum is written below; its prior contents in `seg` don't affect the
|
// (csum is written below; its prior contents in `seg` don't affect the
|
||||||
// computation since we never sum over the segment's own header.)
|
// computation since we never sum over the segment's own header.)
|
||||||
|
|
||||||
@@ -202,7 +244,7 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
|
wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
|
||||||
wide = (wide & 0xffffffff) + (wide >> 32)
|
wide = (wide & 0xffffffff) + (wide >> 32)
|
||||||
wide = (wide & 0xffffffff) + (wide >> 32)
|
wide = (wide & 0xffffffff) + (wide >> 32)
|
||||||
binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], foldComplement(uint32(wide)))
|
binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide)))
|
||||||
|
|
||||||
*out = append(*out, seg)
|
*out = append(*out, seg)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user