From 9d59cba7e1be1e4b1de059d7a6c9caaa0a831613 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 17 Apr 2026 10:25:05 -0500 Subject: [PATCH] first try --- overlay/tun_linux.go | 224 +++++++++++++++++++++------ overlay/tun_linux_offload.go | 234 ++++++++++++++++++++++++++++ overlay/tun_linux_offload_test.go | 247 ++++++++++++++++++++++++++++++ 3 files changed, 657 insertions(+), 48 deletions(-) create mode 100644 overlay/tun_linux_offload.go create mode 100644 overlay/tun_linux_offload_test.go diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 6d7d9fb8..bda1bef2 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -34,6 +34,16 @@ type tunFile struct { readPoll [2]unix.PollFd writePoll [2]unix.PollFd closed bool + + // vnetHdr is true when this fd was opened with IFF_VNET_HDR and the + // kernel successfully accepted TUNSETOFFLOAD. Reads include a leading + // virtio_net_hdr and may carry a TSO superpacket we must segment; + // writes must prepend a zeroed virtio_net_hdr. + vnetHdr bool + readBuf []byte // scratch for a single raw read (virtio hdr + superpacket) + segBuf []byte // backing store for segmented output + pending [][]byte // segments waiting to be drained by Read + pendingIdx int } // newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun @@ -41,9 +51,10 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) { if err := unix.SetNonblock(fd, true); err != nil { return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) } - return &tunFile{ + out := &tunFile{ fd: fd, shutdownFd: r.shutdownFd, + vnetHdr: r.vnetHdr, readPoll: [2]unix.PollFd{ {Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, @@ -52,10 +63,15 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) { {Fd: int32(fd), Events: unix.POLLOUT}, {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, }, - }, nil + } + if r.vnetHdr { + out.readBuf = make([]byte, tunReadBufSize) + out.segBuf = make([]byte, tunSegBufSize) + } + return out, nil } -func newTunFd(fd int) (*tunFile, error) { +func newTunFd(fd int, vnetHdr bool) (*tunFile, error) { if err := unix.SetNonblock(fd, true); err != nil { return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) } @@ -69,6 +85,7 @@ func newTunFd(fd int) (*tunFile, error) { fd: fd, shutdownFd: shutdownFd, lastOne: true, + vnetHdr: vnetHdr, readPoll: [2]unix.PollFd{ {Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(shutdownFd), Events: unix.POLLIN}, @@ -78,6 +95,10 @@ func newTunFd(fd int) (*tunFile, error) { {Fd: int32(shutdownFd), Events: unix.POLLIN}, }, } + if vnetHdr { + out.readBuf = make([]byte, tunReadBufSize) + out.segBuf = make([]byte, tunSegBufSize) + } return out, nil } @@ -134,7 +155,7 @@ func (r *tunFile) blockOnWrite() error { return nil } -func (r *tunFile) Read(buf []byte) (int, error) { +func (r *tunFile) readRaw(buf []byte) (int, error) { for { if n, err := unix.Read(r.fd, buf); err == nil { return n, nil @@ -149,20 +170,80 @@ func (r *tunFile) Read(buf []byte) (int, error) { } } -func (r *tunFile) Write(buf []byte) (int, error) { +func (r *tunFile) Read(buf []byte) (int, error) { + if !r.vnetHdr { + return r.readRaw(buf) + } + for { - if n, err := unix.Write(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { + if r.pendingIdx < len(r.pending) { + seg := r.pending[r.pendingIdx] + r.pendingIdx++ + if len(seg) > len(buf) { + return 0, io.ErrShortBuffer + } + return copy(buf, seg), nil + } + r.pending = r.pending[:0] + r.pendingIdx = 0 + + n, err := r.readRaw(r.readBuf) + if err != nil { + return 0, err + } + if n < virtioNetHdrLen { + return 0, fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen) + } + var hdr virtioNetHdr + hdr.decode(r.readBuf[:virtioNetHdrLen]) + if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf); err != nil { + // Drop and read again — a bad packet should not kill the reader. + continue + } + } +} + +// zeroVnetHdr is the prefix we prepend to every write when IFF_VNET_HDR is +// active and we have no offload info to convey. +var zeroVnetHdr [virtioNetHdrLen]byte + +func (r *tunFile) Write(buf []byte) (int, error) { + if !r.vnetHdr { + for { + if n, err := unix.Write(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnWrite(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else { + return 0, err + } + } + } + + iovs := [][]byte{zeroVnetHdr[:], buf} + for { + n, err := unix.Writev(r.fd, iovs) + if err == nil { + if n < virtioNetHdrLen { + return 0, io.ErrShortWrite + } + return n - virtioNetHdrLen, nil + } + if err == unix.EAGAIN { if err = r.blockOnWrite(); err != nil { return 0, err } continue - } else if err == unix.EINTR { - continue - } else { - return 0, err } + if err == unix.EINTR { + continue + } + return 0, err } } @@ -233,7 +314,9 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) + // We don't know what flags the caller opened this fd with and can't turn + // on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds. + t, err := newTunGeneric(c, l, deviceFd, false, vpnNetworks) if err != nil { return nil, err } @@ -243,46 +326,83 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +// openTunDev opens /dev/net/tun, creating the device node first if it's +// missing (docker containers occasionally omit it). +func openTunDev() (int, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - // If /dev/net/tun doesn't exist, try to create it (will happen in docker) - if os.IsNotExist(err) { - err = os.MkdirAll("/dev/net", 0755) - if err != nil { - return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) - } - err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) - if err != nil { - return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) - } - - fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) - } - } else { - return nil, err - } + if err == nil { + return fd, nil } + if !os.IsNotExist(err) { + return -1, err + } + if err = os.MkdirAll("/dev/net", 0755); err != nil { + return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil { + return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err) + } + fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) + if err != nil { + return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) + } + return fd, nil +} +// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen +// device name on success. +func tunSetIff(fd int, name string, flags uint16) (string, error) { var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + req.Flags = flags + copy(req.Name[:], name) + if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + return "", err + } + return strings.Trim(string(req.Name[:]), "\x00"), nil +} + +// tsoOffloadFlags are the TUN_F_* bits we ask the kernel to enable when a +// TSO-capable TUN is available. CSUM is required as a prerequisite for TSO. +const tsoOffloadFlags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { + baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI) if multiqueue { - req.Flags |= unix.IFF_MULTI_QUEUE + baseFlags |= unix.IFF_MULTI_QUEUE } nameStr := c.GetString("tun.dev", "") - copy(req.Name[:], nameStr) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + + // First try to open with IFF_VNET_HDR + TUNSETOFFLOAD so we can receive + // TSO superpackets. If either step fails (older kernel, unprivileged + // container, etc.) we close and fall back to a plain TUN. + fd, err := openTunDev() + if err != nil { + return nil, err + } + vnetHdr := true + name, err := tunSetIff(fd, nameStr, baseFlags|unix.IFF_VNET_HDR) + if err != nil { _ = unix.Close(fd) - return nil, &NameError{ - Name: nameStr, - Underlying: err, + vnetHdr = false + } else if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil { + l.WithError(err).Warn("Failed to enable TUN offload (TSO); proceeding without virtio headers") + _ = unix.Close(fd) + vnetHdr = false + } + + if !vnetHdr { + fd, err = openTunDev() + if err != nil { + return nil, err + } + name, err = tunSetIff(fd, nameStr, baseFlags) + if err != nil { + _ = unix.Close(fd) + return nil, &NameError{Name: nameStr, Underlying: err} } } - name := strings.Trim(string(req.Name[:]), "\x00") - t, err := newTunGeneric(c, l, fd, vpnNetworks) + t, err := newTunGeneric(c, l, fd, vnetHdr, vpnNetworks) if err != nil { return nil, err } @@ -293,8 +413,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. -func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { - tfd, err := newTunFd(fd) +func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetworks []netip.Prefix) (*tun, error) { + tfd, err := newTunFd(fd, vnetHdr) if err != nil { _ = unix.Close(fd) return nil, err @@ -413,14 +533,22 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } - var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) - copy(req.Name[:], t.Device) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + if t.vnetHdr { + flags |= unix.IFF_VNET_HDR + } + if _, err = tunSetIff(fd, t.Device, flags); err != nil { _ = unix.Close(fd) return nil, err } + if t.vnetHdr { + if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err) + } + } + out, err := t.tunFile.newFriend(fd) if err != nil { _ = unix.Close(fd) diff --git a/overlay/tun_linux_offload.go b/overlay/tun_linux_offload.go new file mode 100644 index 00000000..1d77b443 --- /dev/null +++ b/overlay/tun_linux_offload.go @@ -0,0 +1,234 @@ +//go:build !android && !e2e_testing +// +build !android,!e2e_testing + +package overlay + +import ( + "encoding/binary" + "fmt" + + "golang.org/x/sys/unix" +) + +// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on +// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set). +const virtioNetHdrLen = 10 + +// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A +// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header +// prefix plus the virtio header. +const tunReadBufSize = 65535 + +// Space for segmented output. Worst case is many small segments, each paying +// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling. +const tunSegBufSize = 131072 + +type virtioNetHdr struct { + Flags uint8 + GSOType uint8 + HdrLen uint16 + GSOSize uint16 + CsumStart uint16 + CsumOffset uint16 +} + +// decode reads a virtio_net_hdr in host byte order (TUN default; we never +// call TUNSETVNETLE so the kernel matches our endianness). +func (h *virtioNetHdr) decode(b []byte) { + h.Flags = b[0] + h.GSOType = b[1] + h.HdrLen = binary.NativeEndian.Uint16(b[2:4]) + h.GSOSize = binary.NativeEndian.Uint16(b[4:6]) + h.CsumStart = binary.NativeEndian.Uint16(b[6:8]) + h.CsumOffset = binary.NativeEndian.Uint16(b[8:10]) +} + +// segmentInto splits a TUN-side packet described by hdr into one or more +// IP packets, each appended to *out as a slice of scratch. scratch must be +// sized to hold every segment (including replicated headers). +func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error { + switch hdr.GSOType { + case unix.VIRTIO_NET_HDR_GSO_NONE: + if len(pkt) > len(scratch) { + return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch)) + } + copy(scratch, pkt) + seg := scratch[:len(pkt)] + if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + if err := finishChecksum(seg, hdr); err != nil { + return err + } + } + *out = append(*out, seg) + return nil + + case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6: + return segmentTCP(pkt, hdr, out, scratch) + + default: + return fmt.Errorf("unsupported virtio gso type: %d", hdr.GSOType) + } +} + +// finishChecksum computes the L4 checksum for a non-GSO packet that the kernel +// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit +// checksum field; we zero it, fold a full sum (the field was pre-loaded with +// the pseudo-header partial sum by the kernel), and store the result. +func finishChecksum(seg []byte, hdr virtioNetHdr) error { + cs := int(hdr.CsumStart) + co := int(hdr.CsumOffset) + if cs+co+2 > len(seg) { + return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg)) + } + // The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the + // L4 region starting at cs, folding the prior partial in as the seed. + partial := uint32(binary.BigEndian.Uint16(seg[cs+co : cs+co+2])) + seg[cs+co] = 0 + seg[cs+co+1] = 0 + sum := checksumBytes(seg[cs:], partial) + binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], checksumFold(sum)) + return nil +} + +// segmentTCP software-segments a TSO superpacket into one IP packet per MSS +// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6. +func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error { + if hdr.GSOSize == 0 { + return fmt.Errorf("gso_size is zero") + } + if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 { + return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt)) + } + if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen { + return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen) + } + + isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4 + headerLen := int(hdr.HdrLen) + csumStart := int(hdr.CsumStart) + + if isV4 && csumStart < 20 { + return fmt.Errorf("csum_start %d too small for IPv4", csumStart) + } + if !isV4 && csumStart < 40 { + return fmt.Errorf("csum_start %d too small for IPv6", csumStart) + } + if headerLen-csumStart < 20 { + return fmt.Errorf("tcp header region too small: %d", headerLen-csumStart) + } + + payload := pkt[headerLen:] + payLen := len(payload) + gso := int(hdr.GSOSize) + numSeg := (payLen + gso - 1) / gso + if numSeg == 0 { + numSeg = 1 + } + + need := numSeg*headerLen + payLen + if need > len(scratch) { + return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch)) + } + + origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8]) + origFlags := pkt[csumStart+13] + const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08) + + var origIPID uint16 + if isV4 { + origIPID = binary.BigEndian.Uint16(pkt[4:6]) + } + + off := 0 + for i := 0; i < numSeg; i++ { + segStart := i * gso + segEnd := segStart + gso + if segEnd > payLen { + segEnd = payLen + } + segPayLen := segEnd - segStart + + // Materialise IP+TCP header and this segment's payload chunk. + copy(scratch[off:], pkt[:headerLen]) + copy(scratch[off+headerLen:], payload[segStart:segEnd]) + seg := scratch[off : off+headerLen+segPayLen] + off += headerLen + segPayLen + + // Fix IP header: total/payload length, v4 ID, v4 header csum. + if isV4 { + ihl := int(seg[0]&0x0f) * 4 + if ihl < 20 || ihl > csumStart { + return fmt.Errorf("bad IPv4 IHL: %d", ihl) + } + binary.BigEndian.PutUint16(seg[2:4], uint16(headerLen+segPayLen)) + binary.BigEndian.PutUint16(seg[4:6], origIPID+uint16(i)) + seg[10] = 0 + seg[11] = 0 + binary.BigEndian.PutUint16(seg[10:12], checksumFold(checksumBytes(seg[:ihl], 0))) + } else { + // IPv6 payload length excludes the 40-byte fixed header but + // includes any extension headers that sit between [40:csumStart]. + binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen)) + } + + // Fix TCP header: seq, flags, checksum. + segSeq := origSeq + uint32(segStart) + binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq) + if i != numSeg-1 { + seg[csumStart+13] = origFlags &^ tcpFinPsh + } else { + seg[csumStart+13] = origFlags + } + seg[csumStart+16] = 0 + seg[csumStart+17] = 0 + + tcpLen := headerLen - csumStart + segPayLen + var psum uint32 + if isV4 { + psum = pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, tcpLen) + } else { + psum = pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen) + } + binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(checksumBytes(seg[csumStart:csumStart+tcpLen], psum))) + + *out = append(*out, seg) + } + + return nil +} + +func checksumBytes(b []byte, initial uint32) uint32 { + sum := initial + for len(b) >= 2 { + sum += uint32(binary.BigEndian.Uint16(b[:2])) + b = b[2:] + } + if len(b) == 1 { + sum += uint32(b[0]) << 8 + } + return sum +} + +func checksumFold(sum uint32) uint16 { + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} + +func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint32 { + sum := checksumBytes(src, 0) + sum = checksumBytes(dst, sum) + sum += uint32(proto) + sum += uint32(tcpLen) + return sum +} + +func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint32 { + sum := checksumBytes(src, 0) + sum = checksumBytes(dst, sum) + sum += uint32(tcpLen >> 16) + sum += uint32(tcpLen & 0xffff) + sum += uint32(proto) + return sum +} diff --git a/overlay/tun_linux_offload_test.go b/overlay/tun_linux_offload_test.go new file mode 100644 index 00000000..a00df60a --- /dev/null +++ b/overlay/tun_linux_offload_test.go @@ -0,0 +1,247 @@ +//go:build !android && !e2e_testing +// +build !android,!e2e_testing + +package overlay + +import ( + "encoding/binary" + "testing" + + "golang.org/x/sys/unix" +) + +// verifyChecksum confirms that the one's-complement sum across `b`, optionally +// seeded with a pseudo-header sum, folds to all-ones (valid). +func verifyChecksum(b []byte, pseudo uint32) bool { + sum := checksumBytes(b, pseudo) + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return uint16(sum) == 0xffff +} + +// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of +// `payLen` bytes split at `mss`. +func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) { + t.Helper() + const ipLen = 20 + const tcpLen = 20 + pkt := make([]byte, ipLen+tcpLen+payLen) + + // IPv4 header + pkt[0] = 0x45 // version 4, IHL 5 + // total length is meaningless for TSO but set it anyway + binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen)) + binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID + pkt[8] = 64 // TTL + pkt[9] = unix.IPPROTO_TCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) // src + copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst + + // TCP header + binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport + binary.BigEndian.PutUint16(pkt[22:24], 80) // dport + binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq + binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack + pkt[32] = 0x50 // data offset 5 words + pkt[33] = 0x18 // ACK | PSH + binary.BigEndian.PutUint16(pkt[34:36], 65535) // window + + // payload + for i := 0; i < payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i & 0xff) + } + + return pkt, virtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } +} + +func TestSegmentTCPv4(t *testing.T) { + const mss = 100 + const numSeg = 3 + pkt, hdr := buildTSOv4(t, mss*numSeg, mss) + + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentTCP: %v", err) + } + if len(out) != numSeg { + t.Fatalf("expected %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != 40+mss { + t.Errorf("seg %d: unexpected len %d", i, len(seg)) + } + totalLen := binary.BigEndian.Uint16(seg[2:4]) + if totalLen != uint16(40+mss) { + t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss) + } + id := binary.BigEndian.Uint16(seg[4:6]) + if id != 0x4242+uint16(i) { + t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i)) + } + seq := binary.BigEndian.Uint32(seg[24:28]) + wantSeq := uint32(10000 + i*mss) + if seq != wantSeq { + t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq) + } + flags := seg[33] + wantFlags := byte(0x10) // ACK only, PSH cleared + if i == numSeg-1 { + wantFlags = 0x18 // ACK | PSH preserved on last + } + if flags != wantFlags { + t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags) + } + // IPv4 header checksum must verify against itself. + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + // TCP checksum must verify against the pseudo-header. + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentTCPv4OddTail(t *testing.T) { + // Payload of 250 bytes with MSS 100 → segments of 100, 100, 50. + pkt, hdr := buildTSOv4(t, 250, 100) + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentTCP: %v", err) + } + if len(out) != 3 { + t.Fatalf("want 3 segments, got %d", len(out)) + } + wantPayLens := []int{100, 100, 50} + for i, seg := range out { + if len(seg)-40 != wantPayLens[i] { + t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i]) + } + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i]) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentTCPv6(t *testing.T) { + const ipLen = 40 + const tcpLen = 20 + const mss = 120 + const numSeg = 2 + payLen := mss * numSeg + pkt := make([]byte, ipLen+tcpLen+payLen) + + // IPv6 header + pkt[0] = 0x60 // version 6 + binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen)) + pkt[6] = unix.IPPROTO_TCP + pkt[7] = 64 + // src/dst fe80::1 / fe80::2 + pkt[8] = 0xfe + pkt[9] = 0x80 + pkt[23] = 1 + pkt[24] = 0xfe + pkt[25] = 0x80 + pkt[39] = 2 + + // TCP header + binary.BigEndian.PutUint16(pkt[40:42], 12345) + binary.BigEndian.PutUint16(pkt[42:44], 80) + binary.BigEndian.PutUint32(pkt[44:48], 7) + binary.BigEndian.PutUint32(pkt[48:52], 99) + pkt[52] = 0x50 + pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too + binary.BigEndian.PutUint16(pkt[54:56], 65535) + + for i := 0; i < payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i) + } + + hdr := virtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } + + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentTCP(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentTCP: %v", err) + } + if len(out) != numSeg { + t.Fatalf("want %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != ipLen+tcpLen+mss { + t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss) + } + pl := binary.BigEndian.Uint16(seg[4:6]) + if pl != uint16(tcpLen+mss) { + t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss) + } + seq := binary.BigEndian.Uint32(seg[44:48]) + if seq != uint32(7+i*mss) { + t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss) + } + flags := seg[53] + // Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be + // cleared on all but the last; ACK(0x10) always preserved. + wantFlags := byte(0x10) + if i == numSeg-1 { + wantFlags = 0x19 + } + if flags != wantFlags { + t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags) + } + psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss) + if !verifyChecksum(seg[ipLen:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentGSONonePassesThrough(t *testing.T) { + pkt, hdr := buildTSOv4(t, 100, 100) + hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE + hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched + + scratch := make([]byte, tunSegBufSize) + var out [][]byte + if err := segmentInto(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentInto: %v", err) + } + if len(out) != 1 { + t.Fatalf("want 1 segment, got %d", len(out)) + } + if len(out[0]) != len(pkt) { + t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt)) + } +} + +func TestSegmentRejectsUDP(t *testing.T) { + hdr := virtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP} + var out [][]byte + if err := segmentInto(nil, hdr, &out, nil); err == nil { + t.Fatalf("expected rejection for UDP GSO") + } +}