GSO/GRO offloads, with TCP+ECN and UDP support

JackDoan
2026-04-17 10:25:05 -05:00
parent f95857b4c3
commit 5d35351437
60 changed files with 6915 additions and 283 deletions

View File

@@ -0,0 +1,79 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
type offloadQueueSet struct {
pq []*Offload
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
// usoEnabled is true when newTun successfully negotiated TUN_F_USO4|6
// with the kernel. Queues created by Add inherit this and surface it
// via Offload.USOSupported so coalescers can gate USO emission.
usoEnabled bool
}
// NewOffloadQueueSet creates a QueueSet that uses virtio_net_hdr to do
// TSO segmentation in userspace. usoEnabled tells downstream queues whether
// the kernel agreed to deliver/accept GSO_UDP_L4 superpackets — coalescers
// should fall back to per-packet writes when this is false.
func NewOffloadQueueSet(usoEnabled bool) (QueueSet, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &offloadQueueSet{
pq: []*Offload{},
pqi: []Queue{},
shutdownFd: shutdownFd,
usoEnabled: usoEnabled,
}
return out, nil
}
func (c *offloadQueueSet) Queues() []Queue {
return c.pqi
}
func (c *offloadQueueSet) Add(fd int) error {
x, err := newOffload(fd, c.shutdownFd, c.usoEnabled)
if err != nil {
return err
}
c.pq = append(c.pq, x)
c.pqi = append(c.pqi, x)
return nil
}
func (c *offloadQueueSet) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(c.shutdownFd, buf[:])
return err
}
func (c *offloadQueueSet) Close() error {
errs := []error{}
// Signal all readers blocked in poll to wake up and exit
if err := c.wakeForShutdown(); err != nil {
errs = append(errs, err)
}
for _, x := range c.pq {
if err := x.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
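For orientation, a minimal lifecycle sketch (runQueues, tunFds, and readLoop are hypothetical names, not part of this commit): construct the set, Add each TUN fd, spawn one reader per Queue, and Close on shutdown.

// runQueues is a hypothetical caller sketch, not part of this commit: it
// wires per-queue TUN fds into an offload QueueSet and spawns one reader
// per Queue, matching the single-reader contract documented on Queue.
func runQueues(tunFds []int, usoEnabled bool, readLoop func(Queue)) (QueueSet, error) {
    qs, err := NewOffloadQueueSet(usoEnabled)
    if err != nil {
        return nil, err
    }
    for _, fd := range tunFds {
        if err := qs.Add(fd); err != nil {
            _ = qs.Close() // don't leak already-added queues
            return nil, err
        }
    }
    for _, q := range qs.Queues() {
        go readLoop(q)
    }
    return qs, nil
}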

View File

@@ -8,20 +8,20 @@ import (
"golang.org/x/sys/unix"
)
-type pollContainer struct {
+type pollQueueSet struct {
pq []*Poll
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
}
-func NewPollContainer() (Container, error) {
+func NewPollQueueSet() (QueueSet, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
-out := &pollContainer{
+out := &pollQueueSet{
pq: []*Poll{},
pqi: []Queue{},
shutdownFd: shutdownFd,
@@ -30,11 +30,11 @@ func NewPollContainer() (Container, error) {
return out, nil
}
-func (c *pollContainer) Queues() []Queue {
+func (c *pollQueueSet) Queues() []Queue {
return c.pqi
}
-func (c *pollContainer) Add(fd int) error {
+func (c *pollQueueSet) Add(fd int) error {
x, err := newPoll(fd, c.shutdownFd)
if err != nil {
return err
@@ -45,14 +45,14 @@ func (c *pollContainer) Add(fd int) error {
return nil
}
-func (c *pollContainer) wakeForShutdown() error {
+func (c *pollQueueSet) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:])
return err
}
-func (c *pollContainer) Close() error {
+func (c *pollQueueSet) Close() error {
errs := []error{}
if err := c.wakeForShutdown(); err != nil {

View File

@@ -0,0 +1,65 @@
//go:build linux && !android && !e2e_testing
package tio
import "testing"
// fakeBatch stands in for batch.TxBatcher in the benchmark — the closure
// below captures it the same way sendInsideMessage's closure captures its
// receivers.
type fakeBatch struct{ buf [65536]byte }
func (b *fakeBatch) Reserve(sz int) []byte { return b.buf[:sz] }
func (b *fakeBatch) Commit([]byte) {}
type fakeHostInfo struct {
remoteIndexId uint32
counter uint64
}
type fakeIface struct {
rebindCount uint8
hi *fakeHostInfo
}
// BenchmarkSegmentSuperpacketAllocsTSO measures allocation per
// SegmentSuperpacket call when a closure captures pointer-bearing
// receivers — the realistic shape of sendInsideMessage's closure.
func BenchmarkSegmentSuperpacketAllocsTSO(b *testing.B) {
const mss = 1400
const numSeg = 32
pkt := buildTSOv6(mss*numSeg, mss)
gso := GSOInfo{
Size: mss,
HdrLen: 60, // 40 (IPv6) + 20 (TCP)
CsumStart: 40,
Proto: GSOProtoTCP,
}
p := Packet{Bytes: pkt, GSO: gso}
hi := &fakeHostInfo{remoteIndexId: 0xdeadbeef}
f := &fakeIface{rebindCount: 7, hi: hi}
fb := &fakeBatch{}
// SegmentSuperpacket consumes pkt destructively; refresh from a master
// copy each iter (matches the production pattern where every TUN read
// hands the segmenter a fresh kernel-supplied buffer).
master := append([]byte(nil), pkt...)
work := make([]byte, len(pkt))
p.Bytes = work
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
copy(work, master)
err := SegmentSuperpacket(p, func(seg []byte) error {
out := fb.Reserve(16 + len(seg) + 16)
out[0] = byte(f.rebindCount)
out[1] = byte(hi.counter)
hi.counter++
fb.Commit(out)
return nil
})
if err != nil {
b.Fatalf("SegmentSuperpacket: %v", err)
}
}
}

View File

@@ -0,0 +1,18 @@
//go:build !linux || android || e2e_testing
package tio
import "fmt"
// SegmentSuperpacket invokes fn once per segment of pkt. On non-Linux
// builds (and Android/e2e_testing) this package does not provide a Queue
// implementation, so any caller that constructs a Packet here can only be
// operating on non-superpacket bytes, and the stub forwards them directly
// to fn. A non-zero GSO field is a programming error by the caller, so the
// stub returns an explicit error rather than silently misbehaving.
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
if pkt.GSO.IsSuperpacket() {
return fmt.Errorf("tio: GSO superpacket on platform without segmentation support")
}
return fn(pkt.Bytes)
}

View File

@@ -1,56 +1,170 @@
package tio
import "io"
import (
"io"
)
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
-// Container holds one or many Queue objects and helps close them in an orderly way
-type Container interface {
+// QueueSet holds one or many Queue objects and helps close them in an orderly way.
+type QueueSet interface {
io.Closer
Queues() []Queue
-// Add takes a tun fd, adds it to the container, and prepares it for use as a Queue
+// Add takes a tun fd, adds it to the set, and prepares it for use as a Queue.
Add(fd int) error
}
// Capabilities advertises which kernel offload features a Queue
// successfully negotiated. Callers consult this to decide which coalescers
// to wire onto the write path — a Queue without TSO can't usefully accept a
// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer.
type Capabilities struct {
// TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed
// to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe.
TSO bool
// USO means the kernel additionally agreed to TUN_F_USO4|USO6, so
// WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2.
USO bool
}
// Queue is a readable/writable Poll queue. One Queue is driven by a single
-// read goroutine plus concurrent writers (see Write / WriteReject below).
+// read goroutine plus a single writer (see Write below).
type Queue interface {
io.Closer
-// Read returns one or more packets. The returned slices are borrowed
-// from the Queue's internal buffer and are only valid until the next
-// Read or Close on this Queue - callers must encrypt or copy each
-// slice before the next call. Not safe for concurrent Reads.
-Read() ([][]byte, error)
+// Read returns one or more packets. The returned Packet.Bytes slices
+// are borrowed from the Queue's internal buffer and are only valid
+// until the next Read or Close on this Queue - callers must encrypt
+// or copy each slice before the next call. A Packet may carry a
+// GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is
+// true the caller must segment Bytes before treating it as a single
+// IP datagram. Not safe for concurrent Reads.
+Read() ([]Packet, error)
// Write emits a single packet on the plaintext (outside→inside)
// delivery path. Not safe for concurrent Writes.
Write(p []byte) (int, error)
}
-// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
// Packet is the unit Queue.Read returns. Bytes points into the queue's
// internal buffer and is only valid until the next Read or Close on the
// queue that produced it. GSO is the zero value for an already-segmented
// IP datagram; when non-zero it describes a kernel-supplied TSO/USO
// superpacket the caller must segment before consuming.
type Packet struct {
Bytes []byte
GSO GSOInfo
}
// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes.
// The zero value means "not a superpacket" — Bytes is one regular IP
// datagram and no segmentation is required.
type GSOInfo struct {
// Size is the GSO segment size: max payload bytes per segment
// (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means
// not a superpacket.
Size uint16
// HdrLen is the total L3+L4 header length within Bytes (already
// corrected via correctHdrLen, so safe to slice on).
HdrLen uint16
// CsumStart is the L4 header offset inside Bytes (== L3 header
// length).
CsumStart uint16
// Proto picks the L4 protocol (TCP or UDP) so the segmenter knows
// which checksum/header layout to apply.
Proto GSOProto
}
// IsSuperpacket reports whether g describes a multi-segment GSO/USO
// superpacket that needs segmentation before its bytes can be encrypted
// and sent on the wire.
func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 }
// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes,
// safe to retain past the next Read or Close on the originating Queue.
// GSO metadata is copied verbatim. Use this only when a caller genuinely
// needs to outlive the borrowed-slice contract — the hot path reads should
// continue to consume the borrow synchronously to avoid the allocation.
func (p Packet) Clone() Packet {
if p.Bytes == nil {
return p
}
cp := make([]byte, len(p.Bytes))
copy(cp, p.Bytes)
return Packet{Bytes: cp, GSO: p.GSO}
}
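A small hedged illustration of when Clone is warranted (keepForLater and ch are hypothetical, not part of this commit):

// keepForLater is a hypothetical sketch: Clone is only needed when a
// Packet must outlive the borrowed-slice contract, e.g. when handing it
// to another goroutine instead of consuming it before the next Read.
func keepForLater(pkts []Packet, ch chan<- Packet) {
    for _, pkt := range pkts {
        ch <- pkt.Clone() // pkt.Bytes would be invalidated by the next Read
    }
}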
// CapsProvider is an optional interface implemented by Queues that
// successfully negotiated kernel offload features at open time. Callers
// pick a write-path coalescer based on the result. Queues that don't
// implement it are treated as having no offload capability — callers must
// fall back to plain per-packet writes.
type CapsProvider interface {
Capabilities() Capabilities
}
// QueueCapabilities returns q's negotiated offload capabilities, or the
// zero value when q does not advertise any.
func QueueCapabilities(q Queue) Capabilities {
if cp, ok := q.(CapsProvider); ok {
return cp.Capabilities()
}
return Capabilities{}
}
// GSOProto selects the L4 protocol for a GSO superpacket. It determines
// which VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum
// offset virtio NEEDS_CSUM expects inside the transport header.
type GSOProto uint8
const (
GSOProtoTCP GSOProto = iota
GSOProtoUDP
)
// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket
// assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO
-// support return false from GSOSupported and coalescing is skipped.
+// support do not implement this interface and coalescing is skipped.
//
-// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will
-// have filled in total length and pseudo-header partial). pays are
-// non-overlapping payload fragments whose concatenation is the full
-// superpacket payload; they are read-only from the writer's perspective
-// and must remain valid until the call returns. gsoSize is the MSS:
-// every segment except possibly the last is exactly that many bytes.
-// csumStart is the byte offset where the TCP header begins within hdr.
+// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have
+// filled in total length and IP csum). transportHdr is the TCP or UDP
+// header (mutable - the L4 checksum field must hold the pseudo-header
+// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics).
+// pays are non-overlapping payload fragments whose concatenation is the
+// full superpacket payload; they are read-only from the writer's
+// perspective and must remain valid until the call returns. Every segment
+// in pays except possibly the last is exactly the same size. proto picks
+// the L4 protocol so the writer knows which GSOType / CsumOffset to set.
//
// # TODO fold into Queue
//
-// hdr's TCP checksum field must already hold the pseudo-header partial
-// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
+// Callers should also consult CapsProvider (via SupportsGSO or
+// QueueCapabilities) for the per-protocol negotiated capability; an
+// implementation of GSOWriter is necessary but not sufficient since USO
+// may not have been negotiated even when TSO was.
type GSOWriter interface {
-WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
-GSOSupported() bool
+WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error
}
// SupportsGSO reports whether w implements GSOWriter and the underlying
// queue advertises the negotiated capability for `want`. A writer that
// implements GSOWriter but not CapsProvider is treated as permissive
// (used by tests and fakes that don't negotiate).
func SupportsGSO(w any, want GSOProto) (GSOWriter, bool) {
gw, ok := w.(GSOWriter)
if !ok {
return nil, false
}
cp, ok := w.(CapsProvider)
if !ok {
return gw, true
}
caps := cp.Capabilities()
switch want {
case GSOProtoTCP:
return gw, caps.TSO
case GSOProtoUDP:
return gw, caps.USO
}
return gw, false
}
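To make the gating concrete, a hedged sketch of a write path picking its coalescing mode (chooseWriteMode is a hypothetical name, not part of this commit):

// chooseWriteMode is a hypothetical sketch, not part of this commit; it
// only demonstrates the intended SupportsGSO call pattern. A nil result
// means: fall back to plain per-packet Queue.Write for that protocol.
func chooseWriteMode(q Queue) (tcpGSO, udpGSO GSOWriter) {
    if gw, ok := SupportsGSO(q, GSOProtoTCP); ok {
        tcpGSO = gw // safe to emit VIRTIO_NET_HDR_GSO_TCPV4/6 superpackets
    }
    if gw, ok := SupportsGSO(q, GSOProtoUDP); ok {
        udpGSO = gw // kernel agreed to TUN_F_USO4|6 (Linux >= 6.2)
    }
    return tcpGSO, udpGSO
}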

View File

@@ -0,0 +1,461 @@
package tio
import (
"fmt"
"io"
"log/slog"
"os"
"sync"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// tunRxBufSize is the per-Read worst-case footprint inside rxBuf: one
// kernel-supplied packet body, which is at most ~64 KiB (tunReadBufSize).
// Segmentation happens at encrypt time on a per-routine MTU-sized scratch
// (see SegmentSuperpacket), so rxBuf only holds raw kernel-supplied bytes.
// We round up to give comfortable margin for the drain headroom check
// below.
const tunRxBufSize = 64 * 1024
// tunRxBufCap is the total size we allocate for the per-reader rx
// buffer. With reads landing directly in rxBuf, each drain iteration
// consumes up to tunRxBufSize of headroom for the kernel-supplied bytes.
// Sized to two such iterations so the initial blocking read plus one
// drain read both fit without partial-drop.
const tunRxBufCap = tunRxBufSize * 2
// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64
// gsoMaxIovs caps the iovec budget WriteGSO assembles per call: 3 fixed
// entries (virtio_net_hdr, IP hdr, transport hdr) plus up to gsoMaxIovs-3
// payload fragments. Sized comfortably above the typical kernel GSO
// segment cap (Linux UDP_GRO is 64) so realistic coalesced bursts never
// touch the limit. iovecs are tiny (16 bytes), so the entire scratch is
// 4 KiB — fine to keep resident on every queue. WriteGSO returns an error
// rather than reallocating when a caller exceeds this budget.
const gsoMaxIovs = 256
// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
// verification. All packets that reach the plain Write paths already carry
// a valid L4 checksum (either supplied by a remote peer whose ciphertext we
// AEAD-authenticated, produced by segmentTCPYield/segmentUDPYield during
// superpacket segmentation, or built locally by CreateRejectPacket), so
// trusting them is safe.
var validVnetHdr = [virtio.Size]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type Offload struct {
fd int
shutdownFd int
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
// writeLock serializes blockOnWrite's read+clear of writePoll[*].Revents.
// Any goroutine that calls Write may end up parked in poll(2); without
// the lock concurrent waiters could race the Revents reset and lose
// events.
writeLock sync.Mutex
closed atomic.Bool
rxBuf []byte // backing store for kernel-handed packets read this drain
rxOff int // cursor into rxBuf for the current Read drain
pending []Packet // packets returned from the most recent Read
// readVnetScratch holds the 10-byte virtio_net_hdr split off the front of
// every TUN read via readv(2). Decoupling the header from the packet body
// lets us read the body directly into rxBuf at the current rxOff with
// no userspace copy on the GSO_NONE fast path.
readVnetScratch [virtio.Size]byte
// readIovs is the readv(2) iovec scratch wired once at construction —
// iovec[0] points at readVnetScratch; iovec[1].Base/Len is updated per
// read to address the current rxBuf slot.
readIovs [2]unix.Iovec
// usoEnabled records whether the kernel agreed to TUN_F_USO* on this FD,
// so writers can decide whether emitting GSO_UDP_L4 superpackets is safe.
usoEnabled bool
// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
// by WriteGSO. Kept separate from the read-only package-level validVnetHdr
// so non-GSO Writes can ship that constant directly while WriteGSO
// rewrites this scratch on every call.
gsoHdrBuf [virtio.Size]byte
// gsoIovs is the writev iovec scratch for WriteGSO. Pre-sized to
// gsoMaxIovs at construction; never grown. WriteGSO returns an error
// (and drops the call) if a caller hands it more fragments than fit.
gsoIovs []unix.Iovec
}
func newOffload(fd int, shutdownFd int, usoEnabled bool) (*Offload, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
out := &Offload{
fd: fd,
shutdownFd: shutdownFd,
usoEnabled: usoEnabled,
closed: atomic.Bool{},
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writeLock: sync.Mutex{},
rxBuf: make([]byte, tunRxBufCap),
gsoIovs: make([]unix.Iovec, 2, gsoMaxIovs),
}
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
out.gsoIovs[0].SetLen(virtio.Size)
// readIovs[0] is wired once to the virtio_net_hdr scratch; per-read we
// only repoint readIovs[1] at the next rxBuf slot (see readPacket).
out.readIovs[0].Base = &out.readVnetScratch[0]
out.readIovs[0].SetLen(virtio.Size)
return out, nil
}
func (r *Offload) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.readPoll[0].Revents
shutdownEvents := r.readPoll[1].Revents
r.readPoll[0].Revents = 0
r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *Offload) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
r.writeLock.Lock()
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
r.writeLock.Unlock()
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
// readPacket issues a single readv(2) splitting the virtio_net_hdr off
// into readVnetScratch and reading the packet body directly into rxBuf at
// the current rxOff. Returns the body length (excluding the virtio header
// bytes: just the IP packet/superpacket). block controls whether EAGAIN is
// retried via poll: the initial read of a drain blocks; subsequent drain
// reads do not.
//
// The body iovec capacity is always tunReadBufSize; callers (the Read
// drain loop) gate entry on tunRxBufCap-rxOff >= tunRxBufSize, sized to
// hold one worst-case kernel-supplied packet body. Without that gate the
// body iovec could be smaller than the next inbound packet and the
// kernel would truncate.
func (r *Offload) readPacket(block bool) (int, error) {
for {
r.readIovs[1].Base = &r.rxBuf[r.rxOff]
r.readIovs[1].SetLen(tunReadBufSize)
n, _, errno := syscall.Syscall(unix.SYS_READV, uintptr(r.fd), uintptr(unsafe.Pointer(&r.readIovs[0])), uintptr(len(r.readIovs)))
if errno == 0 {
if int(n) < virtio.Size {
return 0, io.ErrUnexpectedEOF // read shorter than the virtio_net_hdr itself
}
return int(n) - virtio.Size, nil
}
if errno == unix.EAGAIN {
if !block {
return 0, errno
}
if err := r.blockOnRead(); err != nil {
return 0, err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return 0, os.ErrClosed
}
return 0, errno
}
}
// Read returns one or more packets from the tun. Each Packet either
// carries a single ready-to-use IP datagram (GSO zero) or a TSO/USO
// superpacket plus the GSOInfo a caller needs to segment it (see
// SegmentSuperpacket). The first read blocks via poll; once the fd is
// known readable we drain additional packets non-blocking until the
// kernel queue is empty (EAGAIN), we've collected tunDrainCap packets,
// or we're out of rxBuf headroom. This amortizes the poll wake over
// bursts of small packets (e.g. TCP ACKs). Packet.Bytes slices point
// into the Offload's internal buffer and are only valid until the next
// Read or Close on this Queue.
func (r *Offload) Read() ([]Packet, error) {
r.pending = r.pending[:0]
r.rxOff = 0
// Initial (blocking) read. Retry on decode errors so a single bad
// packet does not stall the reader.
for {
n, err := r.readPacket(true)
if err != nil {
return nil, err
}
if err := r.decodeRead(n); err != nil {
// Drop and read again — a bad packet should not kill the reader.
continue
}
break
}
// Drain: non-blocking reads until the kernel queue is empty, the drain
// cap is reached, or rxBuf no longer has room for another worst-case
// kernel-supplied packet (tunRxBufSize).
for len(r.pending) < tunDrainCap && tunRxBufCap-r.rxOff >= tunRxBufSize {
n, err := r.readPacket(false)
if err != nil {
// EAGAIN / EINTR / anything else: stop draining. We already
// have a valid batch from the first read.
break
}
if n <= 0 {
break
}
if err := r.decodeRead(n); err != nil {
// Drop this packet and stop the drain; we'd rather hand off
// what we have than keep spinning here.
break
}
}
return r.pending, nil
}
// decodeRead processes the packet sitting in rxBuf at rxOff (length
// pktLen). The bytes stay in rxBuf — for GSO_NONE we slice them as a
// regular IP datagram (running finishChecksum if NEEDS_CSUM is set);
// for TSO/USO superpackets we attach the corrected GSO metadata so the
// caller can segment lazily at encrypt time. rxOff advances past the
// kernel-supplied body and nothing else, since segmentation no longer
// writes back into rxBuf.
func (r *Offload) decodeRead(pktLen int) error {
if pktLen <= 0 {
return fmt.Errorf("short tun read: %d", pktLen)
}
var hdr virtio.Hdr
hdr.Decode(r.readVnetScratch[:])
body := r.rxBuf[r.rxOff : r.rxOff+pktLen]
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if err := virtio.FinishChecksum(body, hdr); err != nil {
return err
}
}
r.pending = append(r.pending, Packet{Bytes: body})
r.rxOff += pktLen
return nil
}
// GSO superpacket: validate, fix the kernel-supplied HdrLen on the
// FORWARD path (CorrectHdrLen), pick the L4 protocol, and attach
// the metadata. The bytes stay in rxBuf untouched, segmentation
// happens in SegmentSuperpacket at encrypt time.
if err := virtio.CheckValid(body, hdr); err != nil {
return err
}
if err := virtio.CorrectHdrLen(body, &hdr); err != nil {
return err
}
proto, err := protoFromGSOType(hdr.GSOType)
if err != nil {
return err
}
r.pending = append(r.pending, Packet{
Bytes: body,
GSO: GSOInfo{
Size: hdr.GSOSize,
HdrLen: hdr.HdrLen,
CsumStart: hdr.CsumStart,
Proto: proto,
},
})
r.rxOff += pktLen
return nil
}
func (r *Offload) Write(buf []byte) (int, error) {
// writeWithScratch fills iovs[1] (and rejects an empty buf), so only the
// virtio_net_hdr entry is wired here; taking &buf[0] up front would panic
// on a zero-length write.
var iovs [2]unix.Iovec
iovs[0].Base = &validVnetHdr[0]
iovs[0].SetLen(virtio.Size)
return r.writeWithScratch(buf, &iovs)
}
func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
if len(buf) == 0 {
return 0, nil
}
iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf))
return r.rawWrite(unsafe.Slice(&iovs[0], len(iovs)))
}
func (r *Offload) rawWrite(iovs []unix.Iovec) (int, error) {
for {
n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(unsafe.Pointer(&iovs[0])), uintptr(len(iovs)))
if errno == 0 {
if int(n) < virtio.Size {
return 0, io.ErrShortWrite
}
return int(n) - virtio.Size, nil
}
if errno == unix.EAGAIN {
if err := r.blockOnWrite(); err != nil {
return 0, err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return 0, os.ErrClosed
}
return 0, errno
}
}
// Capabilities reports the offload features negotiated for this Queue. TSO
// is always true for Offload (we only construct it on IFF_VNET_HDR FDs);
// USO is true only when the kernel agreed to TUN_F_USO4|6 at open time
// (Linux ≥ 6.2).
func (r *Offload) Capabilities() Capabilities {
return Capabilities{TSO: true, USO: r.usoEnabled}
}
func (r *Offload) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error {
if len(hdr) == 0 || len(pays) == 0 || len(transportHdr) == 0 {
return nil
}
// L4 checksum offset inside transportHdr: TCP=16 (the `check` field after
// seq/ack/dataoff/flags/window), UDP=6 (after sport/dport/length).
var csumOff uint16
switch proto {
case GSOProtoUDP:
csumOff = 6
default:
csumOff = 16
}
vhdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
HdrLen: uint16(len(hdr) + len(transportHdr)),
GSOSize: uint16(len(pays[0])),
CsumStart: uint16(len(hdr)),
CsumOffset: csumOff,
}
if len(pays) > 1 {
ipVer := hdr[0] >> 4
switch {
case proto == GSOProtoUDP && (ipVer == 4 || ipVer == 6):
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
case ipVer == 6:
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
case ipVer == 4:
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
default:
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
vhdr.GSOSize = 0
}
} else {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
vhdr.GSOSize = 0
}
vhdr.Encode(r.gsoHdrBuf[:])
// Build the iovec array: [virtio_hdr, hdr, transportHdr, pays...]. r.gsoIovs[0] is
// wired to gsoHdrBuf at construction and never changes.
need := 3 + len(pays)
if need > cap(r.gsoIovs) {
slog.Default().Warn("tio: WriteGSO iovec budget exceeded; dropping superpacket",
"need", need, "cap", cap(r.gsoIovs), "segments", len(pays))
return fmt.Errorf("tio: WriteGSO needs %d iovecs but cap is %d", need, cap(r.gsoIovs))
}
r.gsoIovs = r.gsoIovs[:need]
r.gsoIovs[1].Base = &hdr[0]
r.gsoIovs[1].SetLen(len(hdr))
r.gsoIovs[2].Base = &transportHdr[0]
r.gsoIovs[2].SetLen(len(transportHdr))
for i, p := range pays {
r.gsoIovs[3+i].Base = &p[0]
r.gsoIovs[3+i].SetLen(len(p))
}
_, err := r.rawWrite(r.gsoIovs)
return err
}
func (r *Offload) Close() error {
if r.closed.Swap(true) {
return nil
}
// shutdownFd is owned by the QueueSet, so we should not close it
var err error
if r.fd >= 0 {
err = unix.Close(r.fd)
r.fd = -1
}
return err
}
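To illustrate WriteGSO's NEEDS_CSUM contract end to end, a hedged caller-side sketch for IPv4/TCP (writeTCPSuperpacket is hypothetical, not part of this commit; it assumes encoding/binary is imported and that the caller has already finished the IP header):

// writeTCPSuperpacket is a hypothetical caller sketch, not part of this
// commit. ipHdr is a 20-byte IPv4 header the caller has already finished
// (total length and IP checksum set); tcpHdr is the 20-byte TCP header;
// pays are per-segment payload fragments, all but the last the same size.
func writeTCPSuperpacket(w GSOWriter, ipHdr, tcpHdr []byte, pays [][]byte) error {
    l4Len := len(tcpHdr)
    for _, p := range pays {
        l4Len += len(p)
    }
    // IPv4 pseudo-header: src, dst, zero+protocol, TCP length, summed as
    // 16-bit big-endian words and folded to 16 bits.
    s := uint32(binary.BigEndian.Uint16(ipHdr[12:14])) +
        uint32(binary.BigEndian.Uint16(ipHdr[14:16])) +
        uint32(binary.BigEndian.Uint16(ipHdr[16:18])) +
        uint32(binary.BigEndian.Uint16(ipHdr[18:20])) +
        uint32(unix.IPPROTO_TCP) + uint32(l4Len)
    for s > 0xffff {
        s = (s & 0xffff) + (s >> 16)
    }
    // Seed the checksum field with the single-fold, NOT inverted partial;
    // the kernel finishes the real per-segment checksum.
    binary.BigEndian.PutUint16(tcpHdr[16:18], uint16(s))
    return w.WriteGSO(ipHdr, tcpHdr, pays, GSOProtoTCP)
}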

View File

@@ -21,7 +21,7 @@ type Poll struct {
closed atomic.Bool
readBuf []byte
-batchRet [1][]byte
+batchRet [1]Packet
}
func newPoll(fd int, shutdownFd int) (*Poll, error) {
@@ -97,12 +97,12 @@ func (t *Poll) blockOnWrite() error {
return nil
}
-func (t *Poll) Read() ([][]byte, error) {
+func (t *Poll) Read() ([]Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
-t.batchRet[0] = t.readBuf[:n]
+t.batchRet[0] = Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -15,7 +15,7 @@ import (
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
-// The caller takes ownership of the read fd (pass it to newOffload / newFriend).
+// The caller takes ownership of the read fd (pass it into a QueueSet).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
@@ -29,7 +29,7 @@ func newReadPipe(t *testing.T) int {
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
pipe1 := newReadPipe(t)
pipe2 := newReadPipe(t)
-parent, err := NewPollContainer()
+parent, err := NewPollQueueSet()
require.NoError(t, err)
require.NoError(t, parent.Add(pipe1))
require.NoError(t, parent.Add(pipe2))

View File

@@ -0,0 +1,51 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"fmt"
"golang.org/x/sys/unix"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// protoFromGSOType maps a virtio_net_hdr GSOType to the GSOProto value the
// segment-time helpers use. Returns an error for GSO_NONE or any unknown
// value — the caller should only invoke this on a confirmed superpacket.
func protoFromGSOType(t uint8) (GSOProto, error) {
switch t {
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
return GSOProtoTCP, nil
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
return GSOProtoUDP, nil
default:
return 0, fmt.Errorf("unsupported virtio gso type: %d", t)
}
}
// SegmentSuperpacket invokes fn once per segment of pkt. For non-GSO pkts
// fn is called once with pkt.Bytes (no segmentation, no copy). For GSO/USO
// superpackets fn is called once per segment with a slice of pkt.Bytes
// holding that segment's plaintext (a freshly-patched L3+L4 header sliced
// in front of the original payload chunk). Segmentation is destructive:
// pkt is consumed by this call and its bytes are in an undefined state
// when SegmentSuperpacket returns. Callers must not retain pkt or any
// earlier seg slice past fn's return for that segment. Aborts and
// returns the first error from fn or from per-segment construction.
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
if !pkt.GSO.IsSuperpacket() {
return fn(pkt.Bytes)
}
switch pkt.GSO.Proto {
case GSOProtoTCP:
return virtio.SegmentTCP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
case GSOProtoUDP:
return virtio.SegmentUDP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
default:
return fmt.Errorf("unsupported gso proto: %d", pkt.GSO.Proto)
}
}
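End to end, the read side then looks roughly like this hedged sketch (drainQueue and encryptAndSend are hypothetical stand-ins, not part of this commit):

// drainQueue is a hypothetical reader sketch, not part of this commit: one
// goroutine per Queue, each borrowed Packet fully consumed (segmented when
// needed) before the next Read invalidates it.
func drainQueue(q Queue, encryptAndSend func(seg []byte) error) error {
    for {
        pkts, err := q.Read()
        if err != nil {
            return err // os.ErrClosed once the QueueSet shuts down
        }
        for _, pkt := range pkts {
            // Non-superpackets reach fn once, untouched; TSO/USO
            // superpackets are sliced per segment. Destructive: pkt must
            // not be retained after this call.
            if err := SegmentSuperpacket(pkt, encryptAndSend); err != nil {
                return err
            }
        }
    }
}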

View File

@@ -0,0 +1,794 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"encoding/binary"
"os"
"testing"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// testSegScratchSize is a generous segmentation scratch sized to fit any
// of the synthetic TSO/USO superpackets these tests generate (one
// worst-case 64 KiB superpacket plus replicated per-segment headers).
const testSegScratchSize = 192 * 1024
// verifyChecksum confirms that the one's-complement sum across `b`, seeded
// with a folded pseudo-header sum, equals all-ones (valid).
func verifyChecksum(b []byte, pseudo uint16) bool {
return checksum.Checksum(b, pseudo) == 0xffff
}
// segmentForTest is the test-only counterpart to the production
// SegmentSuperpacket path. It handles GSO_NONE (with optional
// FinishChecksum) inline and dispatches GSO superpackets through
// SegmentSuperpacket, draining each yielded segment into a
// freshly-copied [][]byte slot so callers can iterate after the call
// returns. Tests pre-set hdr.HdrLen correctly, so CorrectHdrLen is not
// invoked here. The trailing scratch argument is unused now that
// segmentation happens in place; it is kept so call sites stay uniform.
func segmentForTest(pkt []byte, hdr virtio.Hdr, out *[][]byte, _ []byte) error {
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
cp := append([]byte(nil), pkt...)
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if err := virtio.FinishChecksum(cp, hdr); err != nil {
return err
}
}
*out = append(*out, cp)
return nil
}
proto, err := protoFromGSOType(hdr.GSOType)
if err != nil {
return err
}
gso := GSOInfo{
Size: hdr.GSOSize,
HdrLen: hdr.HdrLen,
CsumStart: hdr.CsumStart,
Proto: proto,
}
return SegmentSuperpacket(Packet{Bytes: pkt, GSO: gso}, func(seg []byte) error {
*out = append(*out, append([]byte(nil), seg...))
return nil
})
}
// pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 4 bytes each.
func pseudoHeaderIPv4(src, dst []byte, proto byte, l4Len int) uint16 {
s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0))
s += uint32(proto) + uint32(l4Len)
s = (s & 0xffff) + (s >> 16)
s = (s & 0xffff) + (s >> 16)
return uint16(s)
}
// pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 16 bytes each.
func pseudoHeaderIPv6(src, dst []byte, proto byte, l4Len int) uint16 {
s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0))
s += uint32(l4Len>>16) + uint32(l4Len&0xffff) + uint32(proto)
s = (s & 0xffff) + (s >> 16)
s = (s & 0xffff) + (s >> 16)
return uint16(s)
}
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
// `payLen` bytes split at `mss`.
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtio.Hdr) {
t.Helper()
const ipLen = 20
const tcpLen = 20
pkt := make([]byte, ipLen+tcpLen+payLen)
// IPv4 header
pkt[0] = 0x45 // version 4, IHL 5
// total length is meaningless for TSO but set it anyway
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
pkt[8] = 64 // TTL
pkt[9] = unix.IPPROTO_TCP
copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
// TCP header
binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
binary.BigEndian.PutUint16(pkt[22:24], 80) // dport
binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
pkt[32] = 0x50 // data offset 5 words
pkt[33] = 0x18 // ACK | PSH
binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
// payload
for i := 0; i < payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i & 0xff)
}
return pkt, virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
HdrLen: uint16(ipLen + tcpLen),
GSOSize: uint16(mss),
CsumStart: uint16(ipLen),
CsumOffset: 16,
}
}
func TestSegmentTCPv4(t *testing.T) {
const mss = 100
const numSeg = 3
pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != 40+mss {
t.Errorf("seg %d: unexpected len %d", i, len(seg))
}
totalLen := binary.BigEndian.Uint16(seg[2:4])
if totalLen != uint16(40+mss) {
t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
}
id := binary.BigEndian.Uint16(seg[4:6])
if id != 0x4242+uint16(i) {
t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
}
seq := binary.BigEndian.Uint32(seg[24:28])
wantSeq := uint32(10000 + i*mss)
if seq != wantSeq {
t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
}
flags := seg[33]
wantFlags := byte(0x10) // ACK only, PSH cleared
if i == numSeg-1 {
wantFlags = 0x18 // ACK | PSH preserved on last
}
if flags != wantFlags {
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
}
// IPv4 header checksum must verify against itself.
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
// TCP checksum must verify against the pseudo-header.
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func TestSegmentTCPv4OddTail(t *testing.T) {
// Payload of 250 bytes with MSS 100 → segments of 100, 100, 50.
pkt, hdr := buildTSOv4(t, 250, 100)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 3 {
t.Fatalf("want 3 segments, got %d", len(out))
}
wantPayLens := []int{100, 100, 50}
for i, seg := range out {
if len(seg)-40 != wantPayLens[i] {
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i])
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i])
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func TestSegmentTCPv6(t *testing.T) {
const ipLen = 40
const tcpLen = 20
const mss = 120
const numSeg = 2
payLen := mss * numSeg
pkt := make([]byte, ipLen+tcpLen+payLen)
// IPv6 header
pkt[0] = 0x60 // version 6
binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
pkt[6] = unix.IPPROTO_TCP
pkt[7] = 64
// src/dst fe80::1 / fe80::2
pkt[8] = 0xfe
pkt[9] = 0x80
pkt[23] = 1
pkt[24] = 0xfe
pkt[25] = 0x80
pkt[39] = 2
// TCP header
binary.BigEndian.PutUint16(pkt[40:42], 12345)
binary.BigEndian.PutUint16(pkt[42:44], 80)
binary.BigEndian.PutUint32(pkt[44:48], 7)
binary.BigEndian.PutUint32(pkt[48:52], 99)
pkt[52] = 0x50
pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
binary.BigEndian.PutUint16(pkt[54:56], 65535)
for i := 0; i < payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i)
}
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
HdrLen: uint16(ipLen + tcpLen),
GSOSize: uint16(mss),
CsumStart: uint16(ipLen),
CsumOffset: 16,
}
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("want %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != ipLen+tcpLen+mss {
t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
}
pl := binary.BigEndian.Uint16(seg[4:6])
if pl != uint16(tcpLen+mss) {
t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
}
seq := binary.BigEndian.Uint32(seg[44:48])
if seq != uint32(7+i*mss) {
t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
}
flags := seg[53]
// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
// cleared on all but the last; ACK(0x10) always preserved.
wantFlags := byte(0x10)
if i == numSeg-1 {
wantFlags = 0x19
}
if flags != wantFlags {
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
}
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
if !verifyChecksum(seg[ipLen:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func TestSegmentGSONonePassesThrough(t *testing.T) {
pkt, hdr := buildTSOv4(t, 100, 100)
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 1 {
t.Fatalf("want 1 segment, got %d", len(out))
}
if len(out[0]) != len(pkt) {
t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt))
}
}
// TestSegmentRejectsLegacyUDPGSO ensures the legacy GSO_UDP (UFO) marker is
// still rejected; only modern GSO_UDP_L4 (USO) is supported.
func TestSegmentRejectsLegacyUDPGSO(t *testing.T) {
hdr := virtio.Hdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
var out [][]byte
if err := segmentForTest(nil, hdr, &out, nil); err == nil {
t.Fatalf("expected rejection for legacy UDP GSO")
}
}
// buildUSOv4 builds a synthetic IPv4/UDP USO superpacket with payload of
// payLen bytes, segmented at gsoSize.
func buildUSOv4(t *testing.T, payLen, gsoSize int) ([]byte, virtio.Hdr) {
t.Helper()
const ipLen = 20
const udpLen = 8
pkt := make([]byte, ipLen+udpLen+payLen)
// IPv4 header
pkt[0] = 0x45 // version 4, IHL 5
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+udpLen+payLen))
binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
pkt[8] = 64
pkt[9] = unix.IPPROTO_UDP
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
// UDP header (length + checksum filled in per segment by segmentUDPYield)
binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
binary.BigEndian.PutUint16(pkt[22:24], 53) // dport
for i := 0; i < payLen; i++ {
pkt[ipLen+udpLen+i] = byte(i & 0xff)
}
return pkt, virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
HdrLen: uint16(ipLen + udpLen),
GSOSize: uint16(gsoSize),
CsumStart: uint16(ipLen),
CsumOffset: 6,
}
}
func TestSegmentUDPv4(t *testing.T) {
const gso = 100
const numSeg = 3
pkt, hdr := buildUSOv4(t, gso*numSeg, gso)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != 28+gso {
t.Errorf("seg %d: len %d want %d", i, len(seg), 28+gso)
}
totalLen := binary.BigEndian.Uint16(seg[2:4])
if totalLen != uint16(28+gso) {
t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 28+gso)
}
// kernel UDP-GSO does NOT bump the IPv4 ID across segments; every
// segment carries the same ID as the seed.
id := binary.BigEndian.Uint16(seg[4:6])
if id != 0x4242 {
t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242)
}
udpLen := binary.BigEndian.Uint16(seg[24:26])
if udpLen != uint16(8+gso) {
t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+gso)
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+gso)
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad UDP checksum", i)
}
}
}
func TestSegmentUDPv4OddTail(t *testing.T) {
// 250 bytes payload, gsoSize=100 → segments of 100, 100, 50.
pkt, hdr := buildUSOv4(t, 250, 100)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 3 {
t.Fatalf("want 3 segments, got %d", len(out))
}
wantPay := []int{100, 100, 50}
for i, seg := range out {
if len(seg)-28 != wantPay[i] {
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-28, wantPay[i])
}
udpLen := binary.BigEndian.Uint16(seg[24:26])
if udpLen != uint16(8+wantPay[i]) {
t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+wantPay[i])
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+wantPay[i])
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad UDP checksum", i)
}
}
}
func TestSegmentUDPv6(t *testing.T) {
const ipLen = 40
const udpLen = 8
const gso = 120
const numSeg = 2
payLen := gso * numSeg
pkt := make([]byte, ipLen+udpLen+payLen)
// IPv6 header
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen+payLen))
pkt[6] = unix.IPPROTO_UDP
pkt[7] = 64
pkt[8] = 0xfe
pkt[9] = 0x80
pkt[23] = 1
pkt[24] = 0xfe
pkt[25] = 0x80
pkt[39] = 2
binary.BigEndian.PutUint16(pkt[40:42], 12345)
binary.BigEndian.PutUint16(pkt[42:44], 53)
for i := 0; i < payLen; i++ {
pkt[ipLen+udpLen+i] = byte(i)
}
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
HdrLen: uint16(ipLen + udpLen),
GSOSize: uint16(gso),
CsumStart: uint16(ipLen),
CsumOffset: 6,
}
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("want %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != ipLen+udpLen+gso {
t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+udpLen+gso)
}
pl := binary.BigEndian.Uint16(seg[4:6])
if pl != uint16(udpLen+gso) {
t.Errorf("seg %d: payload_length=%d want %d", i, pl, udpLen+gso)
}
ul := binary.BigEndian.Uint16(seg[ipLen+4 : ipLen+6])
if ul != uint16(udpLen+gso) {
t.Errorf("seg %d: udp len=%d want %d", i, ul, udpLen+gso)
}
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_UDP, udpLen+gso)
if !verifyChecksum(seg[ipLen:], psum) {
t.Errorf("seg %d: bad UDP checksum", i)
}
}
}
// TestSegmentUDPCEPropagates confirms IP-level CE marks on the seed appear on
// every segment. UDP has no transport-level CWR/ECE: the IP TOS/TC byte is
// copied verbatim into every segment by the segment-prefix copy.
func TestSegmentUDPCEPropagates(t *testing.T) {
pkt, hdr := buildUSOv4(t, 200, 100)
pkt[1] = 0x03 // CE codepoint in IP-ECN
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 2 {
t.Fatalf("want 2 segments, got %d", len(out))
}
for i, seg := range out {
if seg[1]&0x03 != 0x03 {
t.Errorf("seg %d: CE missing (tos=%#x)", i, seg[1])
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
}
}
// TestSegmentTCPCwrFirstSegmentOnly confirms RFC 3168 §6.1.2: when a TSO
// burst's seed has CWR set, only the first emitted segment carries CWR.
// ECE is preserved on every segment (different signal, persistent state).
func TestSegmentTCPCwrFirstSegmentOnly(t *testing.T) {
const mss = 100
const numSeg = 3
pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
// Seed flags: CWR | ECE | ACK | PSH.
pkt[33] = 0x80 | 0x40 | 0x10 | 0x08
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
flags := seg[33]
hasCwr := flags&0x80 != 0
hasEce := flags&0x40 != 0
hasPsh := flags&0x08 != 0
wantCwr := i == 0
wantPsh := i == numSeg-1
if hasCwr != wantCwr {
t.Errorf("seg %d: CWR=%v want %v (flags=%#x)", i, hasCwr, wantCwr, flags)
}
if !hasEce {
t.Errorf("seg %d: ECE missing (flags=%#x)", i, flags)
}
if hasPsh != wantPsh {
t.Errorf("seg %d: PSH=%v want %v (flags=%#x)", i, hasPsh, wantPsh, flags)
}
// IP and TCP checksums must still verify after the flag rewrite.
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func BenchmarkSegmentTCPv4(b *testing.B) {
sizes := []struct {
name string
payLen int
mss int
}{
{"64KiB_MSS1460", 65000, 1460},
{"16KiB_MSS1460", 16384, 1460},
{"4KiB_MSS1460", 4096, 1460},
}
for _, sz := range sizes {
b.Run(sz.name, func(b *testing.B) {
const ipLen = 20
const tcpLen = 20
pkt := make([]byte, ipLen+tcpLen+sz.payLen)
pkt[0] = 0x45
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
pkt[8] = 64
pkt[9] = unix.IPPROTO_TCP
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
binary.BigEndian.PutUint16(pkt[20:22], 12345)
binary.BigEndian.PutUint16(pkt[22:24], 80)
binary.BigEndian.PutUint32(pkt[24:28], 10000)
binary.BigEndian.PutUint32(pkt[28:32], 20000)
pkt[32] = 0x50
pkt[33] = 0x18
binary.BigEndian.PutUint16(pkt[34:36], 65535)
for i := 0; i < sz.payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i)
}
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
HdrLen: uint16(ipLen + tcpLen),
GSOSize: uint16(sz.mss),
CsumStart: uint16(ipLen),
CsumOffset: 16,
}
scratch := make([]byte, testSegScratchSize)
out := make([][]byte, 0, 64)
// SegmentSuperpacket consumes its input destructively; restore
// pkt from a master copy each iteration. The restore mirrors the
// kernel→userspace copy that hands a fresh GSO blob to the
// segmenter in production, so it's representative cost rather
// than bench overhead.
master := append([]byte(nil), pkt...)
work := make([]byte, len(pkt))
b.SetBytes(int64(len(pkt)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
copy(work, master)
out = out[:0]
if err := segmentForTest(work, hdr, &out, scratch); err != nil {
b.Fatal(err)
}
}
})
}
}
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
// allocation-free. We write to /dev/null so every call succeeds synchronously.
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
if err != nil {
t.Fatalf("open /dev/null: %v", err)
}
t.Cleanup(func() { _ = unix.Close(fd) })
tf := &Offload{fd: fd}
payload := make([]byte, 1400)
// Warm up (first call may trigger one-time internal allocations elsewhere).
if _, err := tf.Write(payload); err != nil {
t.Fatalf("Write: %v", err)
}
allocs := testing.AllocsPerRun(1000, func() {
if _, err := tf.Write(payload); err != nil {
t.Fatalf("Write: %v", err)
}
})
if allocs != 0 {
t.Fatalf("Write allocated %.1f times per call, want 0", allocs)
}
}
// buildTSOv6 builds a synthetic IPv6/TCP TSO superpacket with payLen bytes
// of payload, segmented at gso. Returns the packet bytes only; the
// virtio_net_hdr is the caller's responsibility.
func buildTSOv6(payLen, gso int) []byte {
const ipLen = 40
const tcpLen = 20
pkt := make([]byte, ipLen+tcpLen+payLen)
pkt[0] = 0x60 // version 6
binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
pkt[6] = unix.IPPROTO_TCP
pkt[7] = 64
pkt[8] = 0xfe
pkt[9] = 0x80
pkt[23] = 1
pkt[24] = 0xfe
pkt[25] = 0x80
pkt[39] = 2
binary.BigEndian.PutUint16(pkt[40:42], 12345)
binary.BigEndian.PutUint16(pkt[42:44], 80)
binary.BigEndian.PutUint32(pkt[44:48], 7)
binary.BigEndian.PutUint32(pkt[48:52], 99)
pkt[52] = 0x50
pkt[53] = 0x10 // ACK only
binary.BigEndian.PutUint16(pkt[54:56], 65535)
for i := 0; i < payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i)
}
return pkt
}
// TestDecodeReadFitsMaxTSOAtDrainThreshold proves the rxBuf sizing is
// correct: when rxOff is at the maximum value the drain headroom check
// allows, decodeRead must still be able to absorb a worst-case 64KiB
// TSO superpacket without dropping the burst. With segmentation deferred
// to encrypt time, decodeRead writes only the kernel-supplied bytes into
// rxBuf, so the size requirement is just "fit one worst-case input."
//
// Regression history: in a prior layout the rx buffer doubled as the
// segmentation output, a near-threshold drain read returned "scratch too
// small", the whole 45-segment TSO burst was dropped, and the remote's TCP
// fast-retransmit collapsed cwnd. Keeping this test in the new layout
// guards against re-introducing a drain headroom shortfall.
func TestDecodeReadFitsMaxTSOAtDrainThreshold(t *testing.T) {
const ipv6HdrLen = 40
const tcpHdrLen = 20
const headerLen = ipv6HdrLen + tcpHdrLen
// Maximum TUN read body. The tunReadBufSize cap on readv's body iovec
// is what bounds the kernel's superpacket length.
pktLen := tunReadBufSize
payLen := pktLen - headerLen
const targetSegs = 64
gsoSize := (payLen + targetSegs - 1) / targetSegs
pkt := buildTSOv6(payLen, gsoSize)
if len(pkt) != pktLen {
t.Fatalf("buildTSOv6 produced %d bytes, want %d", len(pkt), pktLen)
}
o := &Offload{
rxBuf: make([]byte, tunRxBufCap),
}
// rxOff at the maximum value the drain headroom check permits before
// it would refuse another read. Any drain-time read up to this
// threshold MUST still process correctly.
o.rxOff = tunRxBufCap - tunRxBufSize
// Stage the body in rxBuf as if readv(2) just placed it there.
copy(o.rxBuf[o.rxOff:], pkt)
// Encode the matching virtio_net_hdr.
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
HdrLen: uint16(headerLen),
GSOSize: uint16(gsoSize),
CsumStart: uint16(ipv6HdrLen),
CsumOffset: 16,
}
hdr.Encode(o.readVnetScratch[:])
startRxOff := o.rxOff
if err := o.decodeRead(pktLen); err != nil {
t.Fatalf("decodeRead at drain threshold returned %v — rxBuf sizing regression: "+
"tunRxBufSize=%d must hold one worst-case input (%d)",
err, tunRxBufSize, pktLen)
}
if len(o.pending) != 1 {
t.Fatalf("got %d packets, want 1 superpacket entry", len(o.pending))
}
got := o.pending[0]
if !got.GSO.IsSuperpacket() {
t.Fatalf("expected superpacket GSO metadata, got %+v", got.GSO)
}
if got.GSO.Proto != GSOProtoTCP {
t.Errorf("GSO.Proto=%d want TCP", got.GSO.Proto)
}
if got.GSO.Size != uint16(gsoSize) {
t.Errorf("GSO.Size=%d want %d", got.GSO.Size, gsoSize)
}
if got.GSO.HdrLen != uint16(headerLen) {
t.Errorf("GSO.HdrLen=%d want %d", got.GSO.HdrLen, headerLen)
}
if got.GSO.CsumStart != uint16(ipv6HdrLen) {
t.Errorf("GSO.CsumStart=%d want %d", got.GSO.CsumStart, ipv6HdrLen)
}
if len(got.Bytes) != pktLen {
t.Errorf("len(Bytes)=%d want %d", len(got.Bytes), pktLen)
}
// rxOff advances exactly by the kernel-supplied body length — no
// segmentation output to account for any more.
if o.rxOff != startRxOff+pktLen {
t.Errorf("rxOff=%d want %d", o.rxOff, startRxOff+pktLen)
}
if o.rxOff > tunRxBufCap {
t.Fatalf("rxOff=%d overran rxBuf (cap=%d)", o.rxOff, tunRxBufCap)
}
// Validate that segmenting the returned superpacket reproduces the
// expected per-segment IPv6 payload length and TCP checksum.
wantSegs := (payLen + gsoSize - 1) / gsoSize
gotSegs := 0
if err := SegmentSuperpacket(got, func(seg []byte) error {
defer func() { gotSegs++ }()
if len(seg) < headerLen+1 {
t.Errorf("seg %d too short: %d", gotSegs, len(seg))
return nil
}
if seg[0]>>4 != 6 {
t.Errorf("seg %d: bad IP version %#x", gotSegs, seg[0])
}
segPay := len(seg) - headerLen
gotPL := binary.BigEndian.Uint16(seg[4:6])
if gotPL != uint16(tcpHdrLen+segPay) {
t.Errorf("seg %d: payload_len=%d want %d", gotSegs, gotPL, tcpHdrLen+segPay)
}
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpHdrLen+segPay)
if !verifyChecksum(seg[ipv6HdrLen:], psum) {
t.Errorf("seg %d: bad TCP checksum", gotSegs)
}
return nil
}); err != nil {
t.Fatalf("SegmentSuperpacket: %v", err)
}
if gotSegs != wantSegs {
t.Fatalf("got %d segments, want %d", gotSegs, wantSegs)
}
}

View File

@@ -0,0 +1,43 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package virtio
import "encoding/binary"
// Size is the on-wire length of struct virtio_net_hdr the kernel
// prepends/expects on a TUN opened with IFF_VNET_HDR (assuming
// TUNSETVNETHDRSZ has not been used to change it).
const Size = 10
// Hdr is the Go view of the legacy virtio_net_hdr.
type Hdr struct {
Flags uint8
GSOType uint8
HdrLen uint16
GSOSize uint16
CsumStart uint16
CsumOffset uint16
}
// Decode reads a virtio_net_hdr in host byte order (TUN default; we never
// call TUNSETVNETLE so the kernel matches our endianness).
func (h *Hdr) Decode(b []byte) {
h.Flags = b[0]
h.GSOType = b[1]
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
}
// Encode is the inverse of Decode: writes the virtio_net_hdr fields into b
// (must be at least Size bytes). Used to emit a TSO superpacket on egress.
func (h *Hdr) Encode(b []byte) {
b[0] = h.Flags
b[1] = h.GSOType
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
}
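A hedged round-trip illustration of the wire layout these two functions agree on (exampleRoundTrip is hypothetical, not part of this commit; the numeric literals stand in for the unix.VIRTIO_NET_HDR_* constants so this package needs no extra import):

// exampleRoundTrip is a hypothetical illustration, not part of this commit:
// Encode and Decode are exact inverses over the 10-byte legacy layout
// {flags, gso_type, hdr_len, gso_size, csum_start, csum_offset}.
func exampleRoundTrip() Hdr {
    in := Hdr{
        Flags:      1,    // VIRTIO_NET_HDR_F_NEEDS_CSUM
        GSOType:    4,    // VIRTIO_NET_HDR_GSO_TCPV6
        HdrLen:     60,   // 40 (IPv6) + 20 (TCP)
        GSOSize:    1400, // MSS
        CsumStart:  40,
        CsumOffset: 16,
    }
    var buf [Size]byte
    in.Encode(buf[:])
    var out Hdr
    out.Decode(buf[:])
    return out // out == in
}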

View File

@@ -0,0 +1,401 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
// Package virtio implements the pure validation, header-correction, and
// per-segment slicing logic for kernel-supplied TSO/USO superpackets on
// IFF_VNET_HDR TUN devices. It is FD-free and depends only on the byte
// layout of the virtio_net_hdr and the IP/TCP/UDP headers it describes,
// so it can be unit-tested in isolation from the tio Queue runtime.
package virtio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
const (
	ipv4HeaderMinLen = 20 // IHL=5, no options
	ipv4HeaderMaxLen = 60 // IHL=15, max options
	ipv6FixedLen     = 40 // IPv6 base header; extensions would extend this
	tcpHeaderMinLen  = 20 // data-offset=5, no options
	tcpHeaderMaxLen  = 60 // data-offset=15, max options
)

// Byte offsets inside an IPv4 header.
const (
	ipv4TotalLenOff = 2
	ipv4IDOff       = 4
	ipv4ChecksumOff = 10
	ipv4SrcOff      = 12
	ipv4AddrsEnd    = 20 // end of dst address (ipv4SrcOff + 2*4)
)

// Byte offsets inside an IPv6 header.
const (
	ipv6PayloadLenOff = 4
	ipv6SrcOff        = 8
	ipv6AddrsEnd      = 40 // end of dst address (ipv6SrcOff + 2*16)
)

// Byte offsets inside a TCP header (relative to its start, i.e. csumStart).
const (
	tcpSeqOff      = 4
	tcpDataOffOff  = 12 // upper nibble is header len in 32-bit words
	tcpFlagsOff    = 13
	tcpChecksumOff = 16
)

// UDP header is fixed at 8 bytes: {sport, dport, length, checksum}.
const (
	udpHeaderLen   = 8
	udpLengthOff   = 4
	udpChecksumOff = 6
)

// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)

// tcpCwrFlag is cleared on every segment except the first. Per RFC 3168
// §6.1.2 the CWR bit signals a one-shot transition (the sender just halved
// its window) and must appear on the first segment of a TSO burst only.
const tcpCwrFlag = 0x80
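
// Worked example (illustrative): a 3-segment TSO burst whose superpacket
// carries CWR|PSH|ACK (0x98) is emitted by SegmentTCP below as:
//
//	seg 0 (first): CWR kept, PSH cleared    -> 0x90 (CWR|ACK)
//	seg 1 (mid):   CWR cleared, PSH cleared -> 0x10 (ACK)
//	seg 2 (last):  CWR cleared, PSH kept    -> 0x18 (PSH|ACK)
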
// CheckValid rejects packets whose virtio_net_hdr/IP combination would
// cause a downstream miscompute. The TUN should never emit RSC_INFO and
// the GSO type must agree with the IP version nibble.
func CheckValid(pkt []byte, hdr Hdr) error {
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
// carry coalescing info rather than checksum offsets. A TUN writing via
// IFF_VNET_HDR should never emit this, but if it did we would silently
// miscompute the segment checksums — refuse the packet instead.
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
}
if len(pkt) < ipv4HeaderMinLen {
return fmt.Errorf("packet too short")
}
ipVersion := pkt[0] >> 4
switch hdr.GSOType {
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
if ipVersion != 4 {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
if ipVersion != 6 {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
// USO carries either v4 or v6; the leading nibble disambiguates.
if !(ipVersion == 4 || ipVersion == 6) {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
default:
if !(ipVersion == 6 || ipVersion == 4) {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
}
return nil
}
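
// Call-site sketch (illustrative; the surrounding names are hypothetical):
// a queue's read path would decode the vnet header prefix and gate on
// CheckValid before attempting any header fixups or segmentation:
//
//	var hdr Hdr
//	hdr.Decode(buf[:Size])
//	pkt := buf[Size:n]
//	if err := CheckValid(pkt, hdr); err != nil {
//		return err // drop the read rather than miscompute checksums
//	}
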
// CorrectHdrLen rewrites hdr.HdrLen based on the actual transport header
// length read out of pkt. The kernel's hdr.HdrLen on the FORWARD path can
// be the length of the entire first packet, so we don't trust it.
func CorrectHdrLen(pkt []byte, hdr *Hdr) error {
	// Thank you wireguard-go for documenting these edge-cases.
	// Don't trust hdr.HdrLen from the kernel, as it can be equal to the length
	// of the entire first packet when the kernel is handling it as part of a
	// FORWARD path. Instead, parse the transport header length and add it onto
	// CsumStart, which is synonymous with IP header length.
	if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
		hdr.HdrLen = hdr.CsumStart + udpHeaderLen
	} else {
		if len(pkt) <= int(hdr.CsumStart)+tcpDataOffOff {
			return errors.New("packet is too short")
		}
		tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffOff]>>4) * 4
		if tcpHLen < tcpHeaderMinLen || tcpHLen > tcpHeaderMaxLen {
			// A TCP header must be between 20 and 60 bytes in length.
			return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
		}
		hdr.HdrLen = hdr.CsumStart + tcpHLen
	}
	if len(pkt) < int(hdr.HdrLen) {
		return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
	}
	if hdr.HdrLen < hdr.CsumStart {
		return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
	}
	// The 16-bit checksum field itself must lie inside the packet;
	// csum_start + csum_offset is the offset of its first byte.
	cSumAt := int(hdr.CsumStart) + int(hdr.CsumOffset)
	if cSumAt+1 >= len(pkt) {
		return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
	}
	return nil
}
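
// Worked example (illustrative): a forwarded IPv4/TCP superpacket with no
// TCP options arrives with CsumStart=20 and a bogus HdrLen equal to the
// whole first segment's length. The data-offset nibble at pkt[20+12] reads
// 5, so tcpHLen = 5*4 = 20 and HdrLen is rewritten to 20+20 = 40.
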
// SegmentTCP walks a TSO superpacket pkt, yielding each segment as a
// slice into pkt itself. Per-segment plaintext is laid out by sliding a
// freshly-patched copy of the L3+L4 header into pkt at offset i*gsoSize,
// where it sits immediately before that segment's payload chunk in the
// original buffer. The slide is destructive: iter i's header write overwrites
// the last hdrLen bytes of seg_{i-1}'s payload, which is dead by the time
// the next iteration begins. pkt is consumed by this call and must not be
// inspected by the caller after the final yield.
func SegmentTCP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
if gsoSizeU == 0 {
return fmt.Errorf("gso_size is zero")
}
if csumStartU == 0 {
return fmt.Errorf("csum_start is zero")
}
headerLen := int(hdrLenU)
csumStart := int(csumStartU)
isV4 := pkt[0]>>4 == 4
tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
payLen := len(pkt) - headerLen
gsoSize := int(gsoSizeU)
numSeg := (payLen + gsoSize - 1) / gsoSize
if numSeg == 0 {
numSeg = 1
}
origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4])
origFlags := pkt[csumStart+tcpFlagsOff]
var tmp [tcpHeaderMaxLen]byte
copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0
tmp[tcpFlagsOff] = 0
tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0
baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0))
var baseProtoSum uint32
if isV4 {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
} else {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
}
baseProtoSum += uint32(unix.IPPROTO_TCP)
var origIPID uint16
var baseIPHdrSum uint32
if isV4 {
origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2])
ihl := int(pkt[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || ihl > csumStart {
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
}
var ipTmp [ipv4HeaderMaxLen]byte
copy(ipTmp[:ihl], pkt[:ihl])
ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0
ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
}
for i := 0; i < numSeg; i++ {
segStart := i * gsoSize
segEnd := segStart + gsoSize
if segEnd > payLen {
segEnd = payLen
}
segPayLen := segEnd - segStart
segLen := headerLen + segPayLen
headerOff := i * gsoSize
// Slide the header into place immediately before this segment's
// payload. Iter 0's header is already at pkt[:headerLen]; for
// i ≥ 1 we copy from there. The constant-byte fields of pkt[:headerLen]
// survive iter 0's in-place patches (only seq/flags/cksum/totalLen/id
// are touched), and iter 0's stale variable-field values are
// overwritten by the per-segment patches below.
if i > 0 {
copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
}
seg := pkt[headerOff : headerOff+segLen]
segSeq := origSeq + uint32(segStart)
segFlags := origFlags
if i != 0 {
segFlags &^= tcpCwrFlag
}
if i != numSeg-1 {
segFlags &^= tcpFinPshMask
}
totalLen := segLen
if isV4 {
segID := origIPID + uint16(i)
binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID)
ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
} else {
binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
}
binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq)
seg[csumStart+tcpFlagsOff] = segFlags
tcpLen := tcpHdrLen + segPayLen
// Payload bytes still live at their original offset in pkt. The
// header slide above only writes into pkt[i*G : i*G+H], which is
// the tail of seg_{i-1}'s payload (already consumed) and never
// overlaps seg_i's own payload at pkt[H+i*G : H+(i+1)*G].
paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
wide = (wide & 0xffffffff) + (wide >> 32)
wide = (wide & 0xffffffff) + (wide >> 32)
binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide)))
if err := yield(seg); err != nil {
return err
}
}
return nil
}
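
// Layout example (illustrative): with headerLen H=40, gsoSize G=1000 and
// payLen=2500, len(pkt)=2540 and numSeg=3:
//
//	i=0: header already at pkt[0:40],   yield pkt[0:1040]
//	i=1: copy header to pkt[1000:1040], yield pkt[1000:2040]
//	     (clobbers the tail of seg 0's payload, already consumed)
//	i=2: copy header to pkt[2000:2040], yield pkt[2000:2540]
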
// SegmentUDP walks a USO superpacket, sliding a per-segment-patched
// L3+L4 header into pkt at offset i*gsoSize and yielding pkt[i*G:i*G+segLen]
// to the caller. Per-segment patches are total_len + IPv4 csum (or IPv6
// payload_len) plus the UDP length and checksum. pkt is consumed
// destructively; see SegmentTCP for the layout reasoning.
//
// UDP-GSO leaves the IPv4 ID identical across segments (the kernel does not
// bump it), which is why the IP-level per-segment work is limited to
// total_len + IPv4 header checksum (v4) or payload_len (v6).
func SegmentUDP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
if gsoSizeU == 0 {
return fmt.Errorf("gso_size is zero")
}
if csumStartU == 0 {
return fmt.Errorf("csum_start is zero")
}
isV4 := pkt[0]>>4 == 4
headerLen := int(hdrLenU)
csumStart := int(csumStartU)
if headerLen-csumStart != udpHeaderLen {
return fmt.Errorf("udp header len mismatch: %d", headerLen-csumStart)
}
payLen := len(pkt) - headerLen
gsoSize := int(gsoSizeU)
numSeg := (payLen + gsoSize - 1) / gsoSize
if numSeg == 0 {
numSeg = 1
}
var udpTmp [udpHeaderLen]byte
copy(udpTmp[:], pkt[csumStart:headerLen])
udpTmp[udpLengthOff], udpTmp[udpLengthOff+1] = 0, 0
udpTmp[udpChecksumOff], udpTmp[udpChecksumOff+1] = 0, 0
baseUDPHdrSum := uint32(checksum.Checksum(udpTmp[:], 0))
var baseProtoSum uint32
if isV4 {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
} else {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
}
baseProtoSum += uint32(unix.IPPROTO_UDP)
var baseIPHdrSum uint32
if isV4 {
ihl := int(pkt[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || ihl > csumStart {
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
}
var ipTmp [ipv4HeaderMaxLen]byte
copy(ipTmp[:ihl], pkt[:ihl])
ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
}
for i := 0; i < numSeg; i++ {
segStart := i * gsoSize
segEnd := segStart + gsoSize
if segEnd > payLen {
segEnd = payLen
}
segPayLen := segEnd - segStart
segLen := headerLen + segPayLen
headerOff := i * gsoSize
if i > 0 {
copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
}
seg := pkt[headerOff : headerOff+segLen]
totalLen := segLen
udpLen := udpHeaderLen + segPayLen
if isV4 {
binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
ipSum := baseIPHdrSum + uint32(totalLen)
binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
} else {
binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
}
binary.BigEndian.PutUint16(seg[csumStart+udpLengthOff:csumStart+udpLengthOff+2], uint16(udpLen))
paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
		wide := uint64(baseUDPHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		// udpLen contributes twice: once for the pseudo-header length field
		// (not part of baseProtoSum) and once for the UDP header's own
		// length field (zeroed in udpTmp above).
		wide += uint64(udpLen) + uint64(udpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		csum := foldComplement(uint32(wide))
		// A computed checksum of zero is transmitted as 0xffff: an all-zero
		// field means "no checksum" for UDP over IPv4 (RFC 768), and zero is
		// simply invalid for UDP over IPv6.
		if csum == 0 {
			csum = 0xffff
		}
		binary.BigEndian.PutUint16(seg[csumStart+udpChecksumOff:csumStart+udpChecksumOff+2], csum)
if err := yield(seg); err != nil {
return err
}
}
return nil
}
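
// segmentByGSOType is an illustrative dispatch sketch, not part of this
// change: after CheckValid and CorrectHdrLen have run, a caller would pick
// the segmenter from hdr.GSOType, masking off the ECN hint bit
// (unix.VIRTIO_NET_HDR_GSO_ECN) that the kernel may OR into the TCP types.
func segmentByGSOType(pkt []byte, hdr Hdr, yield func(seg []byte) error) error {
	switch hdr.GSOType &^ unix.VIRTIO_NET_HDR_GSO_ECN {
	case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
		return SegmentTCP(pkt, hdr.HdrLen, hdr.CsumStart, hdr.GSOSize, yield)
	case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
		return SegmentUDP(pkt, hdr.HdrLen, hdr.CsumStart, hdr.GSOSize, yield)
	default:
		return fmt.Errorf("unsupported GSO type %#x", hdr.GSOType)
	}
}
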
// FinishChecksum computes the L4 checksum for a non-GSO packet that the kernel
// handed us with NEEDS_CSUM set. csum_start/csum_offset locate the 16-bit
// checksum field, which the kernel pre-loads with the pseudo-header partial
// sum: we read that partial out as the seed, zero the field, checksum the L4
// region starting at csum_start, and store the folded complement in place.
func FinishChecksum(seg []byte, hdr Hdr) error {
cs := int(hdr.CsumStart)
co := int(hdr.CsumOffset)
if cs+co+2 > len(seg) {
return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg))
}
// The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the
// L4 region starting at cs, folding the prior partial in as the seed.
partial := binary.BigEndian.Uint16(seg[cs+co : cs+co+2])
seg[cs+co] = 0
seg[cs+co+1] = 0
binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], ^checksum.Checksum(seg[cs:], partial))
return nil
}
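
// Worked example (illustrative): IPv4/TCP with CsumStart=20, CsumOffset=16.
// The kernel leaves the TCP pseudo-header sum at seg[36:38]; we zero that
// field, checksum seg[20:] seeded with the partial, and write the final
// one's-complement checksum back at seg[36:38].
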
// foldComplement folds a 32-bit one's-complement partial sum to 16 bits and
// complements it, yielding the on-wire Internet checksum value.
func foldComplement(sum uint32) uint16 {
sum = (sum & 0xffff) + (sum >> 16)
sum = (sum & 0xffff) + (sum >> 16)
return ^uint16(sum)
}
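
// e.g. foldComplement(0x12345): 0x1 + 0x2345 = 0x2346 (no further carry),
// and ^0x2346 = 0xdcb9 is the value stored on the wire.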