mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 12:57:38 +02:00
attempt to improve readability
This commit is contained in:
@@ -225,7 +225,6 @@ func (r *Offload) Read() ([][]byte, error) {
|
|||||||
|
|
||||||
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
|
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
|
||||||
// the segments to r.pending, and advances r.segOff by the total scratch used.
|
// the segments to r.pending, and advances r.segOff by the total scratch used.
|
||||||
// Caller must have already ensured r.vnetHdr is true.
|
|
||||||
func (r *Offload) decodeRead(n int) error {
|
func (r *Offload) decodeRead(n int) error {
|
||||||
if n < virtioNetHdrLen {
|
if n < virtioNetHdrLen {
|
||||||
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
|
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package tio
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
@@ -47,10 +48,7 @@ const (
|
|||||||
// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
|
// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
|
||||||
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)
|
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)
|
||||||
|
|
||||||
// segmentInto splits a TUN-side packet described by hdr into one or more
|
func checkVirtioValid(pkt []byte, hdr VirtioNetHdr) error {
|
||||||
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
|
||||||
// sized to hold every segment (including replicated headers).
|
|
||||||
func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
|
|
||||||
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
|
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
|
||||||
// carry coalescing info rather than checksum offsets. A TUN writing via
|
// carry coalescing info rather than checksum offsets. A TUN writing via
|
||||||
// IFF_VNET_HDR should never emit this, but if it did we would silently
|
// IFF_VNET_HDR should never emit this, but if it did we would silently
|
||||||
@@ -58,9 +56,26 @@ func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) er
|
|||||||
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
|
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
|
||||||
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
|
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
|
||||||
}
|
}
|
||||||
|
ipVersion := pkt[0] >> 4
|
||||||
switch hdr.GSOType {
|
switch hdr.GSOType {
|
||||||
case unix.VIRTIO_NET_HDR_GSO_NONE:
|
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
|
||||||
|
if ipVersion != 4 {
|
||||||
|
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
||||||
|
}
|
||||||
|
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
||||||
|
if ipVersion != 6 {
|
||||||
|
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if !(ipVersion == 6 || ipVersion == 4) {
|
||||||
|
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleGSONone(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
|
||||||
if len(pkt) > len(scratch) {
|
if len(pkt) > len(scratch) {
|
||||||
return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
|
return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
|
||||||
}
|
}
|
||||||
@@ -73,7 +88,59 @@ func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) er
|
|||||||
}
|
}
|
||||||
*out = append(*out, seg)
|
*out = append(*out, seg)
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func correctHdrLen(pkt []byte, hdr *VirtioNetHdr) error {
|
||||||
|
// Thank you wireguard-go for documenting these edge-cases
|
||||||
|
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
||||||
|
// of the entire first packet when the kernel is handling it as part of a
|
||||||
|
// FORWARD path. Instead, parse the transport header length and add it onto
|
||||||
|
// csumStart, which is synonymous for IP header length.
|
||||||
|
const tcpDataOffset = 12
|
||||||
|
|
||||||
|
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||||
|
hdr.HdrLen = hdr.CsumStart + 8
|
||||||
|
} else {
|
||||||
|
if len(pkt) <= int(hdr.CsumStart+tcpDataOffset) {
|
||||||
|
return errors.New("packet is too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffset] >> 4 * 4)
|
||||||
|
if tcpHLen < 20 || tcpHLen > 60 {
|
||||||
|
// A TCP header must be between 20 and 60 bytes in length.
|
||||||
|
return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||||
|
}
|
||||||
|
hdr.HdrLen = hdr.CsumStart + tcpHLen
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pkt) < int(hdr.HdrLen) {
|
||||||
|
return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.HdrLen < hdr.CsumStart {
|
||||||
|
return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
|
||||||
|
}
|
||||||
|
cSumAt := int(hdr.CsumStart + hdr.CsumStart)
|
||||||
|
if cSumAt+1 >= len(pkt) {
|
||||||
|
return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// segmentInto splits a TUN-side packet described by hdr into one or more
|
||||||
|
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
||||||
|
// sized to hold every segment (including replicated headers).
|
||||||
|
func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
|
||||||
|
if err := checkVirtioValid(pkt, hdr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
||||||
|
return handleGSONone(pkt, hdr, out, scratch)
|
||||||
|
}
|
||||||
|
if err := correctHdrLen(pkt, &hdr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch hdr.GSOType {
|
||||||
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
||||||
return segmentTCP(pkt, hdr, out, scratch)
|
return segmentTCP(pkt, hdr, out, scratch)
|
||||||
|
|
||||||
@@ -118,35 +185,15 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
|
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||||
|
headerLen := int(hdr.HdrLen) // already corrected by the caller
|
||||||
csumStart := int(hdr.CsumStart)
|
csumStart := int(hdr.CsumStart)
|
||||||
|
|
||||||
if isV4 && csumStart < ipv4HeaderMinLen {
|
|
||||||
return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
|
|
||||||
}
|
|
||||||
if !isV4 && csumStart < ipv6FixedLen {
|
|
||||||
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't trust hdr.HdrLen from the kernel: on some paths it can be set
|
|
||||||
// to the full length of the first packet rather than the true L3+L4 header length.
|
|
||||||
// Instead, read the TCP data-offset field from the packet itself and derive
|
|
||||||
// headerLen = csum_start + tcpHdrLen. Matches wireguard-go's approach.
|
|
||||||
if csumStart+tcpFlagsOff+1 > len(pkt) {
|
|
||||||
return fmt.Errorf("packet too short for tcp header at csum_start=%d (pkt %d)", csumStart, len(pkt))
|
|
||||||
}
|
|
||||||
tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
|
tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
|
||||||
if tcpHdrLen < tcpHeaderMinLen || tcpHdrLen > tcpHeaderMaxLen {
|
|
||||||
return fmt.Errorf("tcp data-offset out of range: %d", tcpHdrLen)
|
|
||||||
}
|
|
||||||
headerLen := csumStart + tcpHdrLen
|
|
||||||
if headerLen > len(pkt) {
|
|
||||||
return fmt.Errorf("derived hdr_len %d > pkt %d", headerLen, len(pkt))
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := pkt[headerLen:]
|
payload := pkt[headerLen:]
|
||||||
payLen := len(payload)
|
payLen := len(payload)
|
||||||
gso := int(hdr.GSOSize)
|
gsoSize := int(hdr.GSOSize)
|
||||||
numSeg := (payLen + gso - 1) / gso
|
numSeg := (payLen + gsoSize - 1) / gsoSize
|
||||||
if numSeg == 0 {
|
if numSeg == 0 {
|
||||||
numSeg = 1
|
numSeg = 1
|
||||||
}
|
}
|
||||||
@@ -197,8 +244,8 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
|
|||||||
|
|
||||||
off := 0
|
off := 0
|
||||||
for i := 0; i < numSeg; i++ {
|
for i := 0; i < numSeg; i++ {
|
||||||
segStart := i * gso
|
segStart := i * gsoSize
|
||||||
segEnd := segStart + gso
|
segEnd := segStart + gsoSize
|
||||||
if segEnd > payLen {
|
if segEnd > payLen {
|
||||||
segEnd = payLen
|
segEnd = payLen
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user