attempt to improve readability

This commit is contained in:
JackDoan
2026-04-27 10:26:15 -05:00
parent 45bc0fc055
commit 8fdd98f639
2 changed files with 88 additions and 42 deletions

View File

@@ -225,7 +225,6 @@ func (r *Offload) Read() ([][]byte, error) {
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
// the segments to r.pending, and advances r.segOff by the total scratch used.
// Caller must have already ensured r.vnetHdr is true.
func (r *Offload) decodeRead(n int) error {
if n < virtioNetHdrLen {
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)

View File

@@ -5,6 +5,7 @@ package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
@@ -47,10 +48,7 @@ const (
// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)
// segmentInto splits a TUN-side packet described by hdr into one or more
// IP packets, each appended to *out as a slice of scratch. scratch must be
// sized to hold every segment (including replicated headers).
func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
func checkVirtioValid(pkt []byte, hdr VirtioNetHdr) error {
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
// carry coalescing info rather than checksum offsets. A TUN writing via
// IFF_VNET_HDR should never emit this, but if it did we would silently
@@ -58,9 +56,26 @@ func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) er
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
}
ipVersion := pkt[0] >> 4
switch hdr.GSOType {
case unix.VIRTIO_NET_HDR_GSO_NONE:
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
if ipVersion != 4 {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
if ipVersion != 6 {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
default:
if !(ipVersion == 6 || ipVersion == 4) {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
}
return nil
}
func handleGSONone(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
if len(pkt) > len(scratch) {
return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
}
@@ -73,7 +88,59 @@ func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) er
}
*out = append(*out, seg)
return nil
}
// correctHdrLen recomputes hdr.HdrLen from the packet contents rather than
// trusting the value carried in the virtio header, then validates the
// resulting offsets against len(pkt).
//
// For VIRTIO_NET_HDR_GSO_UDP_L4 the transport header is taken to be the fixed
// 8-byte UDP header. For every other GSO type the TCP data-offset field at
// csum_start+12 is read from pkt. hdr.HdrLen is overwritten in place.
//
// An error is returned when pkt is too short to hold the TCP data-offset
// byte, when the TCP header length is outside 20..60 bytes, when the derived
// header length exceeds len(pkt) or wrapped around uint16, or when the
// two-byte checksum field at csum_start+csum_offset does not fit inside pkt.
func correctHdrLen(pkt []byte, hdr *VirtioNetHdr) error {
	// Thank you wireguard-go for documenting these edge-cases
	// Don't trust hdr.HdrLen from the kernel as it can be equal to the length
	// of the entire first packet when the kernel is handling it as part of a
	// FORWARD path. Instead, parse the transport header length and add it onto
	// CsumStart, which is synonymous with the IP header length.
	const (
		tcpDataOffset = 12 // byte offset of the data-offset nibble in a TCP header
		udpHeaderLen  = 8
	)

	if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
		hdr.HdrLen = hdr.CsumStart + udpHeaderLen
	} else {
		// Do the offset arithmetic in int: in uint16 a large csum_start wraps
		// around and the bounds check would be made against the wrong index.
		dataOffAt := int(hdr.CsumStart) + tcpDataOffset
		if len(pkt) <= dataOffAt {
			return errors.New("packet is too short")
		}
		// The upper nibble is the header length in 32-bit words.
		tcpHLen := uint16(pkt[dataOffAt]>>4) * 4
		if tcpHLen < 20 || tcpHLen > 60 {
			// A TCP header must be between 20 and 60 bytes in length.
			return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
		}
		hdr.HdrLen = hdr.CsumStart + tcpHLen
	}

	if len(pkt) < int(hdr.HdrLen) {
		return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
	}

	// HdrLen was derived as CsumStart plus a positive length, so it can only
	// be smaller than CsumStart if the uint16 addition above wrapped.
	if hdr.HdrLen < hdr.CsumStart {
		return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
	}

	// The checksum field lives at csum_start + csum_offset and is two bytes
	// wide; both bytes must be inside the packet.
	cSumAt := int(hdr.CsumStart) + int(hdr.CsumOffset)
	if cSumAt+1 >= len(pkt) {
		return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
	}

	return nil
}
// segmentInto splits a TUN-side packet described by hdr into one or more
// IP packets, each appended to *out as a slice of scratch. scratch must be
// sized to hold every segment (including replicated headers).
func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
if err := checkVirtioValid(pkt, hdr); err != nil {
return err
}
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
return handleGSONone(pkt, hdr, out, scratch)
}
if err := correctHdrLen(pkt, &hdr); err != nil {
return err
}
switch hdr.GSOType {
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
return segmentTCP(pkt, hdr, out, scratch)
@@ -118,35 +185,15 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
}
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
headerLen := int(hdr.HdrLen) // already corrected by the caller
csumStart := int(hdr.CsumStart)
if isV4 && csumStart < ipv4HeaderMinLen {
return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
}
if !isV4 && csumStart < ipv6FixedLen {
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
}
// Don't trust hdr.HdrLen from the kernel: on some paths it can be set
// to the full length of the first packet rather than the true L3+L4 header length.
// Instead, read the TCP data-offset field from the packet itself and derive
// headerLen = csum_start + tcpHdrLen. Matches wireguard-go's approach.
if csumStart+tcpFlagsOff+1 > len(pkt) {
return fmt.Errorf("packet too short for tcp header at csum_start=%d (pkt %d)", csumStart, len(pkt))
}
tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
if tcpHdrLen < tcpHeaderMinLen || tcpHdrLen > tcpHeaderMaxLen {
return fmt.Errorf("tcp data-offset out of range: %d", tcpHdrLen)
}
headerLen := csumStart + tcpHdrLen
if headerLen > len(pkt) {
return fmt.Errorf("derived hdr_len %d > pkt %d", headerLen, len(pkt))
}
payload := pkt[headerLen:]
payLen := len(payload)
gso := int(hdr.GSOSize)
numSeg := (payLen + gso - 1) / gso
gsoSize := int(hdr.GSOSize)
numSeg := (payLen + gsoSize - 1) / gsoSize
if numSeg == 0 {
numSeg = 1
}
@@ -197,8 +244,8 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
off := 0
for i := 0; i < numSeg; i++ {
segStart := i * gso
segEnd := segStart + gso
segStart := i * gsoSize
segEnd := segStart + gsoSize
if segEnd > payLen {
segEnd = payLen
}