From 8fdd98f639b58cdfdccbea34aae915284009fb44 Mon Sep 17 00:00:00 2001
From: JackDoan
Date: Mon, 27 Apr 2026 10:26:15 -0500
Subject: [PATCH] attempt to improve readability

---
 overlay/tio/tio_gso_linux.go     |   1 -
 overlay/tio/tun_linux_offload.go | 139 ++++++++++++++++++++++---------
 2 files changed, 98 insertions(+), 42 deletions(-)

diff --git a/overlay/tio/tio_gso_linux.go b/overlay/tio/tio_gso_linux.go
index e95e8791..0df8d0fd 100644
--- a/overlay/tio/tio_gso_linux.go
+++ b/overlay/tio/tio_gso_linux.go
@@ -225,7 +225,6 @@ func (r *Offload) Read() ([][]byte, error) {
 
 // decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
 // the segments to r.pending, and advances r.segOff by the total scratch used.
-// Caller must have already ensured r.vnetHdr is true.
 func (r *Offload) decodeRead(n int) error {
 	if n < virtioNetHdrLen {
 		return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
diff --git a/overlay/tio/tun_linux_offload.go b/overlay/tio/tun_linux_offload.go
index 6e30dec0..0b953a3e 100644
--- a/overlay/tio/tun_linux_offload.go
+++ b/overlay/tio/tun_linux_offload.go
@@ -5,6 +5,7 @@ package tio
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 
 	"golang.org/x/sys/unix"
@@ -47,10 +48,10 @@ const (
 // tcpFinPshMask is cleared on every segment except the last of a TSO burst.
 const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)
 
-// segmentInto splits a TUN-side packet described by hdr into one or more
-// IP packets, each appended to *out as a slice of scratch. scratch must be
-// sized to hold every segment (including replicated headers).
-func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
+// checkVirtioValid rejects headers and packets segmentInto cannot handle:
+// RSC_INFO headers, empty packets, and packets whose IP version nibble does
+// not match the advertised GSO type.
+func checkVirtioValid(pkt []byte, hdr VirtioNetHdr) error {
 	// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
 	// carry coalescing info rather than checksum offsets. A TUN writing via
 	// IFF_VNET_HDR should never emit this, but if it did we would silently
@@ -58,22 +59,98 @@ func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) er
 	if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
 		return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
 	}
-
+	if len(pkt) == 0 {
+		return errors.New("empty packet")
+	}
+	ipVersion := pkt[0] >> 4
 	switch hdr.GSOType {
-	case unix.VIRTIO_NET_HDR_GSO_NONE:
-		if len(pkt) > len(scratch) {
-			return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
+	case unix.VIRTIO_NET_HDR_GSO_TCPV4:
+		if ipVersion != 4 {
+			return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
 		}
-		copy(scratch, pkt)
-		seg := scratch[:len(pkt)]
-		if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
-			if err := finishChecksum(seg, hdr); err != nil {
-				return err
-			}
+	case unix.VIRTIO_NET_HDR_GSO_TCPV6:
+		if ipVersion != 6 {
+			return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
 		}
-		*out = append(*out, seg)
-		return nil
+	default:
+		if !(ipVersion == 6 || ipVersion == 4) {
+			return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
+		}
+	}
+	return nil
+}
+
+// handleGSONone copies an unsegmented packet into scratch, finishes its
+// checksum when NEEDS_CSUM is set, and appends the copy to *out.
+func handleGSONone(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
+	if len(pkt) > len(scratch) {
+		return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
+	}
+	copy(scratch, pkt)
+	seg := scratch[:len(pkt)]
+	if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+		if err := finishChecksum(seg, hdr); err != nil {
+			return err
+		}
+	}
+	*out = append(*out, seg)
+	return nil
+}
+
+// correctHdrLen overwrites hdr.HdrLen with a value derived from the packet
+// itself, then checks that the headers and the checksum field fit inside pkt.
+func correctHdrLen(pkt []byte, hdr *VirtioNetHdr) error {
+	// Thank you wireguard-go for documenting these edge-cases
+	// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+	// of the entire first packet when the kernel is handling it as part of a
+	// FORWARD path. Instead, parse the transport header length and add it onto
+	// csumStart, which is synonymous for IP header length.
+	const tcpDataOffset = 12
+
+	if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+		hdr.HdrLen = hdr.CsumStart + 8
+	} else {
+		if len(pkt) <= int(hdr.CsumStart+tcpDataOffset) {
+			return errors.New("packet is too short")
+		}
+
+		tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffset] >> 4 * 4)
+		if tcpHLen < 20 || tcpHLen > 60 {
+			// A TCP header must be between 20 and 60 bytes in length.
+			return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+		}
+		hdr.HdrLen = hdr.CsumStart + tcpHLen
+	}
+
+	if len(pkt) < int(hdr.HdrLen) {
+		return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
+	}
+
+	if hdr.HdrLen < hdr.CsumStart {
+		return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
+	}
+	cSumAt := int(hdr.CsumStart) + int(hdr.CsumOffset)
+	if cSumAt+1 >= len(pkt) {
+		return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
+	}
+	return nil
+}
+
+// segmentInto splits a TUN-side packet described by hdr into one or more
+// IP packets, each appended to *out as a slice of scratch. scratch must be
+// sized to hold every segment (including replicated headers).
+func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
+	if err := checkVirtioValid(pkt, hdr); err != nil {
+		return err
+	}
+	if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
+		return handleGSONone(pkt, hdr, out, scratch)
+	}
+	if err := correctHdrLen(pkt, &hdr); err != nil {
+		return err
+	}
 
+	switch hdr.GSOType {
 	case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
 		return segmentTCP(pkt, hdr, out, scratch)
 
@@ -118,35 +195,15 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
 	}
 
 	isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
+	headerLen := int(hdr.HdrLen) // already corrected by the caller
 	csumStart := int(hdr.CsumStart)
-	if isV4 && csumStart < ipv4HeaderMinLen {
-		return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
-	}
-	if !isV4 && csumStart < ipv6FixedLen {
-		return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
-	}
-
-	// Don't trust hdr.HdrLen from the kernel: on some paths it can be set
-	// to the full length of the first packet rather than the true L3+L4 header length.
-	// Instead, read the TCP data-offset field from the packet itself and derive
-	// headerLen = csum_start + tcpHdrLen. Matches wireguard-go's approach.
-	if csumStart+tcpFlagsOff+1 > len(pkt) {
-		return fmt.Errorf("packet too short for tcp header at csum_start=%d (pkt %d)", csumStart, len(pkt))
-	}
 	tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
-	if tcpHdrLen < tcpHeaderMinLen || tcpHdrLen > tcpHeaderMaxLen {
-		return fmt.Errorf("tcp data-offset out of range: %d", tcpHdrLen)
-	}
-	headerLen := csumStart + tcpHdrLen
-	if headerLen > len(pkt) {
-		return fmt.Errorf("derived hdr_len %d > pkt %d", headerLen, len(pkt))
-	}
 
 	payload := pkt[headerLen:]
 	payLen := len(payload)
 
-	gso := int(hdr.GSOSize)
-	numSeg := (payLen + gso - 1) / gso
+	gsoSize := int(hdr.GSOSize)
+	numSeg := (payLen + gsoSize - 1) / gsoSize
 	if numSeg == 0 {
 		numSeg = 1
 	}
@@ -197,8 +254,8 @@ func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) err
 
 	off := 0
 	for i := 0; i < numSeg; i++ {
-		segStart := i * gso
-		segEnd := segStart + gso
+		segStart := i * gsoSize
+		segEnd := segStart + gsoSize
 		if segEnd > payLen {
 			segEnd = payLen
 		}