mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
GSO/GRO offloads, with TCP+ECN and UDP support
This commit is contained in:
79
overlay/tio/queueset_gso_linux.go
Normal file
79
overlay/tio/queueset_gso_linux.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// offloadQueueSet owns a set of Offload queues that share a single eventfd
// for coordinated shutdown (see Close / wakeForShutdown).
type offloadQueueSet struct {
	// pq holds the concrete queues so Close can iterate them.
	pq []*Offload
	// pqi is exactly the same as pq, but stored as the interface type
	// so Queues can return it directly.
	pqi []Queue
	// shutdownFd is the shared eventfd; writing to it wakes every queue
	// blocked in poll so shutdown can proceed.
	shutdownFd int
	// usoEnabled is true when newTun successfully negotiated TUN_F_USO4|6
	// with the kernel. Queues created by Add inherit this and surface it
	// via Offload.USOSupported so coalescers can gate USO emission.
	usoEnabled bool
}
|
||||
|
||||
// NewOffloadQueueSet creates a QueueSet that uses virtio_net_hdr to do
|
||||
// TSO segmentation in userspace. usoEnabled tells downstream queues whether
|
||||
// the kernel agreed to deliver/accept GSO_UDP_L4 superpackets — coalescers
|
||||
// should fall back to per-packet writes when this is false.
|
||||
func NewOffloadQueueSet(usoEnabled bool) (QueueSet, error) {
|
||||
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||
}
|
||||
|
||||
out := &offloadQueueSet{
|
||||
pq: []*Offload{},
|
||||
pqi: []Queue{},
|
||||
shutdownFd: shutdownFd,
|
||||
usoEnabled: usoEnabled,
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Queues returns every queue added so far, as the Queue interface type.
// The returned slice is the set's internal slice; callers must not mutate it.
func (c *offloadQueueSet) Queues() []Queue {
	return c.pqi
}
|
||||
|
||||
func (c *offloadQueueSet) Add(fd int) error {
|
||||
x, err := newOffload(fd, c.shutdownFd, c.usoEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.pq = append(c.pq, x)
|
||||
c.pqi = append(c.pqi, x)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *offloadQueueSet) wakeForShutdown() error {
|
||||
var buf [8]byte
|
||||
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||
_, err := unix.Write(c.shutdownFd, buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *offloadQueueSet) Close() error {
|
||||
errs := []error{}
|
||||
|
||||
// Signal all readers blocked in poll to wake up and exit
|
||||
if err := c.wakeForShutdown(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
for _, x := range c.pq {
|
||||
if err := x.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
@@ -8,20 +8,20 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type pollContainer struct {
|
||||
type pollQueueSet struct {
|
||||
pq []*Poll
|
||||
// pqi is exactly the same as pq, but stored as the interface type
|
||||
pqi []Queue
|
||||
shutdownFd int
|
||||
}
|
||||
|
||||
func NewPollContainer() (Container, error) {
|
||||
func NewPollQueueSet() (QueueSet, error) {
|
||||
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||
}
|
||||
|
||||
out := &pollContainer{
|
||||
out := &pollQueueSet{
|
||||
pq: []*Poll{},
|
||||
pqi: []Queue{},
|
||||
shutdownFd: shutdownFd,
|
||||
@@ -30,11 +30,11 @@ func NewPollContainer() (Container, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *pollContainer) Queues() []Queue {
|
||||
func (c *pollQueueSet) Queues() []Queue {
|
||||
return c.pqi
|
||||
}
|
||||
|
||||
func (c *pollContainer) Add(fd int) error {
|
||||
func (c *pollQueueSet) Add(fd int) error {
|
||||
x, err := newPoll(fd, c.shutdownFd)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -45,14 +45,14 @@ func (c *pollContainer) Add(fd int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *pollContainer) wakeForShutdown() error {
|
||||
func (c *pollQueueSet) wakeForShutdown() error {
|
||||
var buf [8]byte
|
||||
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||
_, err := unix.Write(int(c.shutdownFd), buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pollContainer) Close() error {
|
||||
func (c *pollQueueSet) Close() error {
|
||||
errs := []error{}
|
||||
|
||||
if err := c.wakeForShutdown(); err != nil {
|
||||
65
overlay/tio/segment_bench_test.go
Normal file
65
overlay/tio/segment_bench_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
|
||||
package tio
|
||||
|
||||
import "testing"
|
||||
|
||||
// fakeBatch stands in for batch.TxBatcher inside the bench — same shape
// of pointer-capturing closure that sendInsideMessage builds.
type fakeBatch struct{ buf [65536]byte }

// Reserve hands back a prefix of the fixed backing array; no allocation.
func (b *fakeBatch) Reserve(sz int) []byte { return b.buf[:sz] }

// Commit is a no-op; the bench only measures the closure's alloc profile.
func (b *fakeBatch) Commit([]byte) {}
|
||||
|
||||
// fakeHostInfo mimics pointer-bearing per-peer state that the benchmarked
// closure reads and mutates (see hi.counter++ in the bench body).
type fakeHostInfo struct {
	remoteIndexId uint32
	counter       uint64
}

// fakeIface mimics an interface object whose fields the closure captures.
type fakeIface struct {
	rebindCount uint8
	hi          *fakeHostInfo
}
|
||||
|
||||
// BenchmarkSegmentSuperpacketAllocsTSO measures allocation per
// SegmentSuperpacket call when a closure captures pointer-bearing
// receivers — the realistic shape of sendInsideMessage's closure.
func BenchmarkSegmentSuperpacketAllocsTSO(b *testing.B) {
	const mss = 1400
	const numSeg = 32
	pkt := buildTSOv6(mss*numSeg, mss)
	gso := GSOInfo{
		Size:      mss,
		HdrLen:    60, // 40 (IPv6) + 20 (TCP)
		CsumStart: 40,
		Proto:     GSOProtoTCP,
	}
	p := Packet{Bytes: pkt, GSO: gso}

	hi := &fakeHostInfo{remoteIndexId: 0xdeadbeef}
	f := &fakeIface{rebindCount: 7, hi: hi}
	fb := &fakeBatch{}

	// SegmentSuperpacket consumes pkt destructively; refresh from a master
	// copy each iter (matches the production pattern where every TUN read
	// hands the segmenter a fresh kernel-supplied buffer).
	master := append([]byte(nil), pkt...)
	work := make([]byte, len(pkt))
	p.Bytes = work

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		copy(work, master)
		err := SegmentSuperpacket(p, func(seg []byte) error {
			// Touch the captured pointers the way the real closure would,
			// so the escape/alloc profile matches production.
			out := fb.Reserve(16 + len(seg) + 16)
			out[0] = byte(f.rebindCount)
			out[1] = byte(hi.counter)
			hi.counter++
			fb.Commit(out)
			return nil
		})
		if err != nil {
			b.Fatalf("SegmentSuperpacket: %v", err)
		}
	}
}
|
||||
18
overlay/tio/segment_other.go
Normal file
18
overlay/tio/segment_other.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build !linux || android || e2e_testing
|
||||
|
||||
package tio
|
||||
|
||||
import "fmt"
|
||||
|
||||
// SegmentSuperpacket invokes fn once per segment of pkt. On non-Linux
// builds (and Android/e2e_testing) this package does not provide a Queue
// implementation, so any caller that does construct a Packet here can only
// be operating on non-superpacket bytes and the stub forwards them
// directly. A non-zero GSO field is a programming error from the caller
// and returns an explicit error rather than silently misbehaving.
//
// scratch is unused by this stub.
//
// NOTE(review): the linux-only benchmark calls SegmentSuperpacket with two
// arguments (pkt, fn) while this stub takes three — confirm the
// build-tagged variants share one signature, otherwise portable callers
// cannot compile on both platforms.
func SegmentSuperpacket(pkt Packet, scratch []byte, fn func(seg []byte) error) error {
	if pkt.GSO.IsSuperpacket() {
		return fmt.Errorf("tio: GSO superpacket on platform without segmentation support")
	}
	return fn(pkt.Bytes)
}
|
||||
@@ -1,56 +1,170 @@
|
||||
package tio
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
|
||||
// that don't do TSO segmentation. 65535 covers any single IP packet.
|
||||
const defaultBatchBufSize = 65535
|
||||
|
||||
// Container holds one or many Queue objects and helps close them in an orderly way
|
||||
type Container interface {
|
||||
// QueueSet holds one or many Queue objects and helps close them in an orderly way.
|
||||
type QueueSet interface {
|
||||
io.Closer
|
||||
Queues() []Queue
|
||||
|
||||
// Add takes a tun fd, adds it to the container, and prepares it for use as a Queue
|
||||
// Add takes a tun fd, adds it to the set, and prepares it for use as a Queue.
|
||||
Add(fd int) error
|
||||
}
|
||||
|
||||
// Capabilities advertises which kernel offload features a Queue
// successfully negotiated. Callers consult this to decide which coalescers
// to wire onto the write path — a Queue without TSO can't usefully accept a
// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer.
// The zero value means no offloads were negotiated.
type Capabilities struct {
	// TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed
	// to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe.
	TSO bool
	// USO means the kernel additionally agreed to TUN_F_USO4|USO6, so
	// WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2.
	USO bool
}
|
||||
|
||||
// Queue is a readable/writable Poll queue. One Queue is driven by a single
|
||||
// read goroutine plus concurrent writers (see Write / WriteReject below).
|
||||
// read goroutine plus a single writer (see Write below).
|
||||
type Queue interface {
|
||||
io.Closer
|
||||
|
||||
// Read returns one or more packets. The returned slices are borrowed
|
||||
// from the Queue's internal buffer and are only valid until the next
|
||||
// Read or Close on this Queue - callers must encrypt or copy each
|
||||
// slice before the next call. Not safe for concurrent Reads.
|
||||
Read() ([][]byte, error)
|
||||
// Read returns one or more packets. The returned Packet.Bytes slices
|
||||
// are borrowed from the Queue's internal buffer and are only valid
|
||||
// until the next Read or Close on this Queue - callers must encrypt
|
||||
// or copy each slice before the next call. A Packet may carry a
|
||||
// GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is
|
||||
// true the caller must segment Bytes before treating it as a single
|
||||
// IP datagram. Not safe for concurrent Reads.
|
||||
Read() ([]Packet, error)
|
||||
|
||||
// Write emits a single packet on the plaintext (outside→inside)
|
||||
// delivery path. Not safe for concurrent Writes.
|
||||
Write(p []byte) (int, error)
|
||||
}
|
||||
|
||||
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
|
||||
// Packet is the unit Queue.Read returns. Bytes points into the queue's
// internal buffer and is only valid until the next Read or Close on the
// queue that produced it. GSO is the zero value for an already-segmented
// IP datagram; when non-zero it describes a kernel-supplied TSO/USO
// superpacket the caller must segment before consuming. Use Clone to
// retain a Packet past the borrowed-slice lifetime.
type Packet struct {
	Bytes []byte
	GSO   GSOInfo
}
|
||||
|
||||
// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes.
// The zero value means "not a superpacket" — Bytes is one regular IP
// datagram and no segmentation is required.
type GSOInfo struct {
	// Size is the GSO segment size: max payload bytes per segment
	// (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means
	// not a superpacket.
	Size uint16
	// HdrLen is the total L3+L4 header length within Bytes (already
	// corrected via correctHdrLen, so safe to slice on).
	HdrLen uint16
	// CsumStart is the L4 header offset inside Bytes (== L3 header
	// length).
	CsumStart uint16
	// Proto picks the L4 protocol (TCP or UDP) so the segmenter knows
	// which checksum/header layout to apply.
	Proto GSOProto
}
|
||||
|
||||
// IsSuperpacket reports whether g describes a multi-segment GSO/USO
|
||||
// superpacket that needs segmentation before its bytes can be encrypted
|
||||
// and sent on the wire.
|
||||
func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 }
|
||||
|
||||
// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes,
|
||||
// safe to retain past the next Read or Close on the originating Queue.
|
||||
// GSO metadata is copied verbatim. Use this only when a caller genuinely
|
||||
// needs to outlive the borrowed-slice contract — the hot path reads should
|
||||
// continue to consume the borrow synchronously to avoid the allocation.
|
||||
func (p Packet) Clone() Packet {
|
||||
if p.Bytes == nil {
|
||||
return p
|
||||
}
|
||||
cp := make([]byte, len(p.Bytes))
|
||||
copy(cp, p.Bytes)
|
||||
return Packet{Bytes: cp, GSO: p.GSO}
|
||||
}
|
||||
|
||||
// CapsProvider is an optional interface implemented by Queues that
// successfully negotiated kernel offload features at open time. Callers
// pick a write-path coalescer based on the result. Queues that don't
// implement it are treated as having no offload capability — callers must
// fall back to plain per-packet writes. See QueueCapabilities for the
// convenience accessor that applies that default.
type CapsProvider interface {
	Capabilities() Capabilities
}
|
||||
|
||||
// QueueCapabilities returns q's negotiated offload capabilities, or the
|
||||
// zero value when q does not advertise any.
|
||||
func QueueCapabilities(q Queue) Capabilities {
|
||||
if cp, ok := q.(CapsProvider); ok {
|
||||
return cp.Capabilities()
|
||||
}
|
||||
return Capabilities{}
|
||||
}
|
||||
|
||||
// GSOProto selects the L4 protocol for a GSO superpacket. Determines which
// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset
// inside the transport header virtio NEEDS_CSUM expects.
type GSOProto uint8

const (
	// GSOProtoTCP marks a TCP (TSO) superpacket. Zero value — the default.
	GSOProtoTCP GSOProto = iota
	// GSOProtoUDP marks a UDP (USO / GSO_UDP_L4) superpacket.
	GSOProtoUDP
)
|
||||
|
||||
// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket
|
||||
// assembled from a header prefix plus one or more borrowed payload
|
||||
// fragments, in a single vectored write (writev with a leading
|
||||
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
|
||||
// between the caller's decrypt buffer and the TUN. Backends without GSO
|
||||
// support return false from GSOSupported and coalescing is skipped.
|
||||
// support do not implement this interface and coalescing is skipped.
|
||||
//
|
||||
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will
|
||||
// have filled in total length and pseudo-header partial). pays are
|
||||
// non-overlapping payload fragments whose concatenation is the full
|
||||
// superpacket payload; they are read-only from the writer's perspective
|
||||
// and must remain valid until the call returns. gsoSize is the MSS:
|
||||
// every segment except possibly the last is exactly that many bytes.
|
||||
// csumStart is the byte offset where the TCP header begins within hdr.
|
||||
// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have
|
||||
// filled in total length and IP csum). transportHdr is the TCP or UDP
|
||||
// header (mutable - the L4 checksum field must hold the pseudo-header
|
||||
// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics).
|
||||
// pays are non-overlapping payload fragments whose concatenation is the
|
||||
// full superpacket payload; they are read-only from the writer's
|
||||
// perspective and must remain valid until the call returns. Every segment
|
||||
// in pays except possibly the last is exactly the same size. proto picks
|
||||
// the L4 protocol so the writer knows which GSOType / CsumOffset to set.
|
||||
//
|
||||
// # TODO fold into Queue
|
||||
//
|
||||
// hdr's TCP checksum field must already hold the pseudo-header partial
|
||||
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
|
||||
// Callers should also consult CapsProvider (via SupportsGSO or
|
||||
// QueueCapabilities) for the per-protocol negotiated capability; an
|
||||
// implementation of GSOWriter is necessary but not sufficient since USO
|
||||
// may not have been negotiated even when TSO was.
|
||||
type GSOWriter interface {
|
||||
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
|
||||
GSOSupported() bool
|
||||
WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error
|
||||
}
|
||||
|
||||
// SupportsGSO reports whether w implements GSOWriter and the underlying
|
||||
// queue advertises the negotiated capability for `want`. A writer that
|
||||
// implements GSOWriter but not CapsProvider is treated as permissive
|
||||
// (used by tests and fakes that don't negotiate).
|
||||
func SupportsGSO(w any, want GSOProto) (GSOWriter, bool) {
|
||||
gw, ok := w.(GSOWriter)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cp, ok := w.(CapsProvider)
|
||||
if !ok {
|
||||
return gw, true
|
||||
}
|
||||
caps := cp.Capabilities()
|
||||
switch want {
|
||||
case GSOProtoTCP:
|
||||
return gw, caps.TSO
|
||||
case GSOProtoUDP:
|
||||
return gw, caps.USO
|
||||
}
|
||||
return gw, false
|
||||
}
|
||||
|
||||
461
overlay/tio/tio_gso_linux.go
Normal file
461
overlay/tio/tio_gso_linux.go
Normal file
@@ -0,0 +1,461 @@
|
||||
package tio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/tio/virtio"
|
||||
)
|
||||
|
||||
// tunRxBufSize is the per-Read worst-case footprint inside rxBuf: one
// kernel-supplied packet body, which is at most ~64 KiB (tunReadBufSize —
// presumably declared elsewhere in this package; confirm). Segmentation
// happens at encrypt time on a per-routine MTU-sized scratch (see
// SegmentSuperpacket), so rxBuf only holds raw kernel-supplied bytes.
// We round up to give comfortable margin for the drain headroom check
// below.
const tunRxBufSize = 64 * 1024

// tunRxBufCap is the total size we allocate for the per-reader rx
// buffer. With reads landing directly in rxBuf, each drain iteration
// consumes up to tunRxBufSize of headroom for the kernel-supplied bytes.
// Sized to two such iterations so the initial blocking read plus one
// drain read both fit without partial-drop.
const tunRxBufCap = tunRxBufSize * 2

// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64

// gsoMaxIovs caps the iovec budget WriteGSO assembles per call: 3 fixed
// entries (virtio_net_hdr, IP hdr, transport hdr) plus up to gsoMaxIovs-3
// payload fragments. Sized comfortably above the typical kernel GSO
// segment cap (Linux UDP_GRO is 64) so realistic coalesced bursts never
// touch the limit. iovecs are tiny (16 bytes), so the entire scratch is
// 4 KiB — fine to keep resident on every queue. WriteGSO returns an error
// rather than reallocating when a caller exceeds this budget.
const gsoMaxIovs = 256

// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
// verification. All packets that reach the plain Write paths already carry
// a valid L4 checksum (either supplied by a remote peer whose ciphertext we
// AEAD-authenticated, produced by segmentTCPYield/segmentUDPYield during
// superpacket segmentation, or built locally by CreateRejectPacket), so
// trusting them is safe.
var validVnetHdr = [virtio.Size]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
|
||||
|
||||
// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type Offload struct {
	// fd is the TUN fd (switched to non-blocking in newOffload).
	fd int
	// shutdownFd is the eventfd shared with the owning QueueSet; POLLIN
	// on it means "shut down".
	shutdownFd int
	// readPoll / writePoll are the pre-built poll(2) fd sets: slot 0 is
	// the tun fd (POLLIN / POLLOUT respectively), slot 1 the shutdown fd.
	readPoll  [2]unix.PollFd
	writePoll [2]unix.PollFd
	// writeLock serializes blockOnWrite's read+clear of writePoll[*].Revents.
	// Any goroutine that calls Write may end up parked in poll(2); without
	// the lock concurrent waiters could race the Revents reset and lose
	// events.
	writeLock sync.Mutex
	closed    atomic.Bool
	rxBuf     []byte   // backing store for kernel-handed packets read this drain
	rxOff     int      // cursor into rxBuf for the current Read drain
	pending   []Packet // packets returned from the most recent Read

	// readVnetScratch holds the 10-byte virtio_net_hdr split off the front of
	// every TUN read via readv(2). Decoupling the header from the packet body
	// lets us read the body directly into rxBuf at the current rxOff with
	// no userspace copy on the GSO_NONE fast path.
	readVnetScratch [virtio.Size]byte
	// readIovs is the readv(2) iovec scratch wired once at construction —
	// iovec[0] points at readVnetScratch; iovec[1].Base/Len is updated per
	// read to address the current rxBuf slot.
	readIovs [2]unix.Iovec

	// usoEnabled records whether the kernel agreed to TUN_F_USO* on this FD,
	// so writers can decide whether emitting GSO_UDP_L4 superpackets is safe.
	usoEnabled bool

	// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
	// by WriteGSO. Kept separate from the read-only package-level validVnetHdr
	// so non-GSO Writes can ship that constant directly while WriteGSO
	// rewrites this scratch on every call.
	gsoHdrBuf [virtio.Size]byte
	// gsoIovs is the writev iovec scratch for WriteGSO. Pre-sized to
	// gsoMaxIovs at construction; never grown. WriteGSO returns an error
	// (and drops the call) if a caller hands it more fragments than fit.
	gsoIovs []unix.Iovec
}
|
||||
|
||||
// newOffload wraps fd (switched to non-blocking) in an Offload queue that
// shares shutdownFd for coordinated wakeup. It pre-wires the poll fd sets,
// the readv iovec scratch, and the WriteGSO iovec budget at construction so
// the hot read/write paths do not allocate.
func newOffload(fd int, shutdownFd int, usoEnabled bool) (*Offload, error) {
	if err := unix.SetNonblock(fd, true); err != nil {
		return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
	}

	out := &Offload{
		fd:         fd,
		shutdownFd: shutdownFd,
		usoEnabled: usoEnabled,
		closed:     atomic.Bool{},
		// Slot 0 watches the tun fd, slot 1 the shutdown eventfd.
		readPoll: [2]unix.PollFd{
			{Fd: int32(fd), Events: unix.POLLIN},
			{Fd: int32(shutdownFd), Events: unix.POLLIN},
		},
		writePoll: [2]unix.PollFd{
			{Fd: int32(fd), Events: unix.POLLOUT},
			{Fd: int32(shutdownFd), Events: unix.POLLIN},
		},
		writeLock: sync.Mutex{},

		rxBuf: make([]byte, tunRxBufCap),
		// len 2 (virtio hdr + IP hdr slots); grows via reslice up to
		// gsoMaxIovs inside WriteGSO.
		gsoIovs: make([]unix.Iovec, 2, gsoMaxIovs),
	}

	// gsoIovs[0] always ships the per-queue virtio_net_hdr scratch.
	out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
	out.gsoIovs[0].SetLen(virtio.Size)

	// readIovs[0] is wired once to the virtio_net_hdr scratch; per-read we
	// only repoint readIovs[1] at the next rxBuf slot (see readPacket).
	out.readIovs[0].Base = &out.readVnetScratch[0]
	out.readIovs[0].SetLen(virtio.Size)

	return out, nil
}
|
||||
|
||||
// blockOnRead parks in poll(2) until the tun fd is readable or the shared
// shutdown eventfd fires, retrying EINTR. It returns os.ErrClosed when the
// shutdown fd signals (or either fd reports POLLHUP/POLLNVAL/POLLERR), and
// nil when the tun fd is ready. Revents are read and cleared without a lock
// — contrast blockOnWrite — which assumes a single reader goroutine (Read
// is documented as not safe for concurrent use).
func (r *Offload) blockOnRead() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var err error
	for {
		_, err = unix.Poll(r.readPoll[:], -1)
		if err != unix.EINTR {
			break
		}
	}
	//always reset these!
	tunEvents := r.readPoll[0].Revents
	shutdownEvents := r.readPoll[1].Revents
	r.readPoll[0].Revents = 0
	r.readPoll[1].Revents = 0
	//do the err check before trusting the potentially bogus bits we just got
	if err != nil {
		return err
	}
	if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
		return os.ErrClosed
	} else if tunEvents&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
|
||||
|
||||
// blockOnWrite parks in poll(2) until the tun fd is writable or the shared
// shutdown eventfd fires, retrying EINTR. Returns os.ErrClosed on shutdown
// or on POLLHUP/POLLNVAL/POLLERR from either fd, nil when the tun fd is
// ready for writing. Unlike blockOnRead, the Revents snapshot+reset is
// guarded by writeLock because multiple goroutines may call Write and race
// the reset (see the writeLock field comment).
func (r *Offload) blockOnWrite() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var err error
	for {
		_, err = unix.Poll(r.writePoll[:], -1)
		if err != unix.EINTR {
			break
		}
	}
	//always reset these!
	r.writeLock.Lock()
	tunEvents := r.writePoll[0].Revents
	shutdownEvents := r.writePoll[1].Revents
	r.writePoll[0].Revents = 0
	r.writePoll[1].Revents = 0
	r.writeLock.Unlock()
	//do the err check before trusting the potentially bogus bits we just got
	if err != nil {
		return err
	}
	if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
		return os.ErrClosed
	} else if tunEvents&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
|
||||
|
||||
// readPacket issues a single readv(2) splitting the virtio_net_hdr off
|
||||
// into readVnetScratch and reading the packet body directly into rxBuf at
|
||||
// the current rxOff. Returns the body length (zero virtio header bytes,
|
||||
// just the IP packet/superpacket). block controls whether EAGAIN is
|
||||
// retried via poll: the initial read of a drain blocks; subsequent drain
|
||||
// reads do not.
|
||||
//
|
||||
// The body iovec capacity is always tunReadBufSize; callers (the Read
|
||||
// drain loop) gate entry on tunRxBufCap-rxOff >= tunRxBufSize, sized to
|
||||
// hold one worst-case kernel-supplied packet body. Without that gate the
|
||||
// body iovec could be smaller than the next inbound packet and the
|
||||
// kernel would truncate.
|
||||
func (r *Offload) readPacket(block bool) (int, error) {
|
||||
for {
|
||||
r.readIovs[1].Base = &r.rxBuf[r.rxOff]
|
||||
r.readIovs[1].SetLen(tunReadBufSize)
|
||||
n, _, errno := syscall.Syscall(unix.SYS_READV, uintptr(r.fd), uintptr(unsafe.Pointer(&r.readIovs[0])), uintptr(len(r.readIovs)))
|
||||
if errno == 0 {
|
||||
if int(n) < virtio.Size {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return int(n) - virtio.Size, nil
|
||||
}
|
||||
if errno == unix.EAGAIN {
|
||||
if !block {
|
||||
return 0, errno
|
||||
}
|
||||
if err := r.blockOnRead(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
if errno == unix.EBADF {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
return 0, errno
|
||||
}
|
||||
}
|
||||
|
||||
// Read returns one or more packets from the tun. Each Packet either
// carries a single ready-to-use IP datagram (GSO zero) or a TSO/USO
// superpacket plus the GSOInfo a caller needs to segment it (see
// SegmentSuperpacket). The first read blocks via poll; once the fd is
// known readable we drain additional packets non-blocking until the
// kernel queue is empty (EAGAIN), we've collected tunDrainCap packets,
// or we're out of rxBuf headroom. This amortizes the poll wake over
// bursts of small packets (e.g. TCP ACKs). Packet.Bytes slices point
// into the Offload's internal buffer and are only valid until the next
// Read or Close on this Queue.
func (r *Offload) Read() ([]Packet, error) {
	// Each Read invalidates the previous batch: reset the pending list and
	// the rxBuf cursor before refilling.
	r.pending = r.pending[:0]
	r.rxOff = 0

	// Initial (blocking) read. Retry on decode errors so a single bad
	// packet does not stall the reader.
	for {
		n, err := r.readPacket(true)
		if err != nil {
			return nil, err
		}
		if err := r.decodeRead(n); err != nil {
			// Drop and read again — a bad packet should not kill the reader.
			continue
		}
		break
	}

	// Drain: non-blocking reads until the kernel queue is empty, the drain
	// cap is reached, or rxBuf no longer has room for another worst-case
	// kernel-supplied packet (tunRxBufSize).
	for len(r.pending) < tunDrainCap && tunRxBufCap-r.rxOff >= tunRxBufSize {
		n, err := r.readPacket(false)
		if err != nil {
			// EAGAIN / EINTR / anything else: stop draining. We already
			// have a valid batch from the first read.
			break
		}
		if n <= 0 {
			break
		}
		if err := r.decodeRead(n); err != nil {
			// Drop this packet and stop the drain; we'd rather hand off
			// what we have than keep spinning here.
			break
		}
	}

	return r.pending, nil
}
|
||||
|
||||
// decodeRead processes the packet sitting in rxBuf at rxOff (length
// pktLen). The bytes stay in rxBuf — for GSO_NONE we slice them as a
// regular IP datagram (running finishChecksum if NEEDS_CSUM is set);
// for TSO/USO superpackets we attach the corrected GSO metadata so the
// caller can segment lazily at encrypt time. rxOff advances past the
// kernel-supplied body and nothing else, since segmentation no longer
// writes back into rxBuf. On error nothing is appended to pending and
// rxOff is left unchanged, so the slot is reused by the next read.
func (r *Offload) decodeRead(pktLen int) error {
	if pktLen <= 0 {
		return fmt.Errorf("short tun read: %d", pktLen)
	}
	// The virtio_net_hdr was split into readVnetScratch by readv.
	var hdr virtio.Hdr
	hdr.Decode(r.readVnetScratch[:])

	body := r.rxBuf[r.rxOff : r.rxOff+pktLen]

	if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
		// Plain datagram fast path: finish the partial checksum if the
		// kernel asked for it, then hand the bytes over as-is.
		if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
			if err := virtio.FinishChecksum(body, hdr); err != nil {
				return err
			}
		}
		r.pending = append(r.pending, Packet{Bytes: body})
		r.rxOff += pktLen
		return nil
	}

	// GSO superpacket: validate, fix the kernel-supplied HdrLen on the
	// FORWARD path (CorrectHdrLen), pick the L4 protocol, and attach
	// the metadata. The bytes stay in rxBuf untouched, segmentation
	// happens in SegmentSuperpacket at encrypt time.
	if err := virtio.CheckValid(body, hdr); err != nil {
		return err
	}
	if err := virtio.CorrectHdrLen(body, &hdr); err != nil {
		return err
	}
	proto, err := protoFromGSOType(hdr.GSOType)
	if err != nil {
		return err
	}
	r.pending = append(r.pending, Packet{
		Bytes: body,
		GSO: GSOInfo{
			Size:      hdr.GSOSize,
			HdrLen:    hdr.HdrLen,
			CsumStart: hdr.CsumStart,
			Proto:     proto,
		},
	})
	r.rxOff += pktLen
	return nil
}
|
||||
|
||||
func (r *Offload) Write(buf []byte) (int, error) {
|
||||
iovs := [2]unix.Iovec{
|
||||
{Base: &validVnetHdr[0]},
|
||||
{Base: &buf[0]},
|
||||
}
|
||||
iovs[0].SetLen(virtio.Size)
|
||||
iovs[1].SetLen(len(buf))
|
||||
return r.writeWithScratch(buf, &iovs)
|
||||
}
|
||||
|
||||
// writeWithScratch points iovs[1] at buf and issues the vectored write.
// iovs[0] must already address a virtio_net_hdr (the caller wires it).
// A zero-length buf is a successful no-op.
func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
	if len(buf) == 0 {
		return 0, nil
	}
	iovs[1].Base = &buf[0]
	iovs[1].SetLen(len(buf))
	return r.rawWrite(unsafe.Slice(&iovs[0], len(iovs)))
}
|
||||
|
||||
// rawWrite issues writev(2) over iovs, retrying EINTR inline and EAGAIN
// via blockOnWrite. On success it returns the byte count minus the
// virtio_net_hdr (which iovs[0] is expected to carry), i.e. the packet
// bytes actually delivered. EBADF is mapped to os.ErrClosed; a write
// shorter than the virtio header is reported as io.ErrShortWrite.
func (r *Offload) rawWrite(iovs []unix.Iovec) (int, error) {
	for {
		n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(unsafe.Pointer(&iovs[0])), uintptr(len(iovs)))
		if errno == 0 {
			if int(n) < virtio.Size {
				return 0, io.ErrShortWrite
			}
			return int(n) - virtio.Size, nil
		}
		if errno == unix.EAGAIN {
			if err := r.blockOnWrite(); err != nil {
				return 0, err
			}
			continue
		}
		if errno == unix.EINTR {
			continue
		}
		if errno == unix.EBADF {
			return 0, os.ErrClosed
		}
		return 0, errno
	}
}
|
||||
|
||||
// Capabilities reports the offload features negotiated for this Queue. TSO
// is always true for Offload (we only construct it on IFF_VNET_HDR FDs);
// USO is true only when the kernel agreed to TUN_F_USO4|6 at open time
// (Linux ≥ 6.2), as recorded in usoEnabled at construction.
func (r *Offload) Capabilities() Capabilities {
	return Capabilities{TSO: true, USO: r.usoEnabled}
}
|
||||
|
||||
// WriteGSO writes one coalesced superpacket to the TUN fd as a gather
// write: [virtio_net_hdr, hdr (L3), transportHdr (L4), pays...]. With
// more than one payload chunk it emits a GSO superpacket segmented at
// len(pays[0]); with a single chunk (or an unrecognized IP version) it
// falls back to a plain GSO_NONE write. Empty hdr/transportHdr/pays are
// silently ignored. Reuses r.gsoHdrBuf and r.gsoIovs, so it is not safe
// for concurrent use on the same Offload — NOTE(review): assumed
// single-writer per queue; confirm against callers.
func (r *Offload) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error {
	if len(hdr) == 0 || len(pays) == 0 || len(transportHdr) == 0 {
		return nil
	}
	// L4 checksum offset inside transportHdr: TCP=16 (the `check` field after
	// seq/ack/dataoff/flags/window), UDP=6 (after sport/dport/length).
	var csumOff uint16
	switch proto {
	case GSOProtoUDP:
		csumOff = 6
	default:
		csumOff = 16
	}
	vhdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		HdrLen:     uint16(len(hdr) + len(transportHdr)),
		GSOSize:    uint16(len(pays[0])),
		CsumStart:  uint16(len(hdr)),
		CsumOffset: csumOff,
	}
	if len(pays) > 1 {
		// Multiple chunks: choose the GSO type from the requested proto
		// and the IP version nibble of the L3 header.
		ipVer := hdr[0] >> 4
		switch {
		case proto == GSOProtoUDP && (ipVer == 4 || ipVer == 6):
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
		case ipVer == 6:
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
		case ipVer == 4:
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
		default:
			// Unrecognized IP version: downgrade to a non-GSO write.
			// GSOSize must be zero when GSOType is GSO_NONE.
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
			vhdr.GSOSize = 0
		}
	} else {
		// Single chunk — nothing for the kernel to segment.
		vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
		vhdr.GSOSize = 0
	}
	vhdr.Encode(r.gsoHdrBuf[:])

	// Build the iovec array: [virtio_hdr, hdr, transportHdr, pays...]. r.gsoIovs[0] is
	// wired to gsoHdrBuf at construction and never changes.
	need := 3 + len(pays)
	if need > cap(r.gsoIovs) {
		// Dropping rather than growing keeps the write path allocation-free;
		// the warn makes the (should-be-impossible) overflow visible.
		slog.Default().Warn("tio: WriteGSO iovec budget exceeded; dropping superpacket",
			"need", need, "cap", cap(r.gsoIovs), "segments", len(pays))
		return fmt.Errorf("tio: WriteGSO needs %d iovecs but cap is %d", need, cap(r.gsoIovs))
	}
	r.gsoIovs = r.gsoIovs[:need]
	r.gsoIovs[1].Base = &hdr[0]
	r.gsoIovs[1].SetLen(len(hdr))
	r.gsoIovs[2].Base = &transportHdr[0]
	r.gsoIovs[2].SetLen(len(transportHdr))
	for i, p := range pays {
		r.gsoIovs[3+i].Base = &p[0]
		r.gsoIovs[3+i].SetLen(len(p))
	}

	_, err := r.rawWrite(r.gsoIovs)
	return err
}
|
||||
|
||||
func (r *Offload) Close() error {
|
||||
if r.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
//shutdownFd is owned by the container, so we should not close it
|
||||
var err error
|
||||
if r.fd >= 0 {
|
||||
err = unix.Close(r.fd)
|
||||
r.fd = -1
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -21,7 +21,7 @@ type Poll struct {
|
||||
closed atomic.Bool
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
batchRet [1]Packet
|
||||
}
|
||||
|
||||
func newPoll(fd int, shutdownFd int) (*Poll, error) {
|
||||
@@ -97,12 +97,12 @@ func (t *Poll) blockOnWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Poll) Read() ([][]byte, error) {
|
||||
func (t *Poll) Read() ([]Packet, error) {
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
t.batchRet[0] = Packet{Bytes: t.readBuf[:n]}
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
||||
// The caller takes ownership of the read fd (pass it to newOffload / newFriend).
|
||||
// The caller takes ownership of the read fd (pass it into a QueueSet).
|
||||
func newReadPipe(t *testing.T) int {
|
||||
t.Helper()
|
||||
var fds [2]int
|
||||
@@ -29,7 +29,7 @@ func newReadPipe(t *testing.T) int {
|
||||
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||
pipe1 := newReadPipe(t)
|
||||
pipe2 := newReadPipe(t)
|
||||
parent, err := NewPollContainer()
|
||||
parent, err := NewPollQueueSet()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, parent.Add(pipe1))
|
||||
require.NoError(t, parent.Add(pipe2))
|
||||
|
||||
51
overlay/tio/tun_linux_offload.go
Normal file
51
overlay/tio/tun_linux_offload.go
Normal file
@@ -0,0 +1,51 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
package tio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/tio/virtio"
|
||||
)
|
||||
|
||||
// protoFromGSOType maps a virtio_net_hdr GSOType to the GSOProto value the
|
||||
// segment-time helpers use. Returns an error for GSO_NONE or any unknown
|
||||
// value — the caller should only invoke this on a confirmed superpacket.
|
||||
func protoFromGSOType(t uint8) (GSOProto, error) {
|
||||
switch t {
|
||||
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
||||
return GSOProtoTCP, nil
|
||||
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
|
||||
return GSOProtoUDP, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported virtio gso type: %d", t)
|
||||
}
|
||||
}
|
||||
|
||||
// SegmentSuperpacket invokes fn once per segment of pkt. For non-GSO pkts
|
||||
// fn is called once with pkt.Bytes (no segmentation, no copy). For GSO/USO
|
||||
// superpackets fn is called once per segment with a slice of pkt.Bytes
|
||||
// holding that segment's plaintext (a freshly-patched L3+L4 header sliced
|
||||
// in front of the original payload chunk). The slide is destructive: pkt is
|
||||
// consumed by this call and its bytes are in an undefined state when
|
||||
// SegmentSuperpacket returns. Callers must not retain pkt or any earlier
|
||||
// seg slice past fn's return for that segment. The scratch parameter is
|
||||
// unused on the destructive path and kept only for cross-platform
|
||||
// signature compatibility. Aborts and returns the first error from fn or
|
||||
// from per-segment construction.
|
||||
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
|
||||
if !pkt.GSO.IsSuperpacket() {
|
||||
return fn(pkt.Bytes)
|
||||
}
|
||||
switch pkt.GSO.Proto {
|
||||
case GSOProtoTCP:
|
||||
return virtio.SegmentTCP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
|
||||
case GSOProtoUDP:
|
||||
return virtio.SegmentUDP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
|
||||
default:
|
||||
return fmt.Errorf("unsupported gso proto: %d", pkt.GSO.Proto)
|
||||
}
|
||||
}
|
||||
794
overlay/tio/tun_linux_offload_test.go
Normal file
794
overlay/tio/tun_linux_offload_test.go
Normal file
@@ -0,0 +1,794 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
package tio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/tio/virtio"
|
||||
)
|
||||
|
||||
// testSegScratchSize is a generous segmentation scratch sized to fit any
|
||||
// of the synthetic TSO/USO superpackets these tests generate (one
|
||||
// worst-case 64 KiB superpacket plus replicated per-segment headers).
|
||||
const testSegScratchSize = 192 * 1024
|
||||
|
||||
// verifyChecksum confirms that the one's-complement sum across `b`, seeded
|
||||
// with a folded pseudo-header sum, equals all-ones (valid).
|
||||
func verifyChecksum(b []byte, pseudo uint16) bool {
|
||||
return checksum.Checksum(b, pseudo) == 0xffff
|
||||
}
|
||||
|
||||
// segmentForTest is the test-only counterpart to the production
|
||||
// SegmentSuperpacket path. It handles GSO_NONE (with optional
|
||||
// finishChecksum) inline and dispatches GSO superpackets through
|
||||
// SegmentSuperpacket, draining each yielded segment into a
|
||||
// freshly-copied [][]byte slot so callers can iterate after the call
|
||||
// returns. Tests pre-set hdr.HdrLen correctly, so correctHdrLen is not
|
||||
// invoked here.
|
||||
func segmentForTest(pkt []byte, hdr virtio.Hdr, out *[][]byte, scratch []byte) error {
|
||||
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
||||
cp := append([]byte(nil), pkt...)
|
||||
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||
if err := virtio.FinishChecksum(cp, hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*out = append(*out, cp)
|
||||
return nil
|
||||
}
|
||||
proto, err := protoFromGSOType(hdr.GSOType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gso := GSOInfo{
|
||||
Size: hdr.GSOSize,
|
||||
HdrLen: hdr.HdrLen,
|
||||
CsumStart: hdr.CsumStart,
|
||||
Proto: proto,
|
||||
}
|
||||
return SegmentSuperpacket(Packet{Bytes: pkt, GSO: gso}, func(seg []byte) error {
|
||||
*out = append(*out, append([]byte(nil), seg...))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a
|
||||
// TCP/UDP segment's checksum in tests. src/dst are 4 bytes each.
|
||||
func pseudoHeaderIPv4(src, dst []byte, proto byte, l4Len int) uint16 {
|
||||
s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0))
|
||||
s += uint32(proto) + uint32(l4Len)
|
||||
s = (s & 0xffff) + (s >> 16)
|
||||
s = (s & 0xffff) + (s >> 16)
|
||||
return uint16(s)
|
||||
}
|
||||
|
||||
// pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a
|
||||
// TCP/UDP segment's checksum in tests. src/dst are 16 bytes each.
|
||||
func pseudoHeaderIPv6(src, dst []byte, proto byte, l4Len int) uint16 {
|
||||
s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0))
|
||||
s += uint32(l4Len>>16) + uint32(l4Len&0xffff) + uint32(proto)
|
||||
s = (s & 0xffff) + (s >> 16)
|
||||
s = (s & 0xffff) + (s >> 16)
|
||||
return uint16(s)
|
||||
}
|
||||
|
||||
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
// `payLen` bytes split at `mss`. Returns the packet bytes together with
// the matching virtio_net_hdr (NEEDS_CSUM + GSO_TCPV4, CsumOffset 16 =
// the TCP `check` field). Seed values the tests assert against: IP ID
// 0x4242, seq 10000, flags ACK|PSH.
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtio.Hdr) {
	t.Helper()
	const ipLen = 20
	const tcpLen = 20
	pkt := make([]byte, ipLen+tcpLen+payLen)

	// IPv4 header
	pkt[0] = 0x45 // version 4, IHL 5
	// total length is meaningless for TSO but set it anyway
	binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
	binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
	pkt[8] = 64                                  // TTL
	pkt[9] = unix.IPPROTO_TCP
	copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
	copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst

	// TCP header
	binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
	binary.BigEndian.PutUint16(pkt[22:24], 80)    // dport
	binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
	binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
	pkt[32] = 0x50                                // data offset 5 words
	pkt[33] = 0x18                                // ACK | PSH
	binary.BigEndian.PutUint16(pkt[34:36], 65535) // window

	// payload: deterministic byte pattern so corrupted splits are visible
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i & 0xff)
	}

	return pkt, virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen),
		CsumOffset: 16,
	}
}
|
||||
|
||||
// TestSegmentTCPv4 segments an even 3×MSS IPv4/TCP superpacket and checks
// every per-segment rewrite: lengths, incrementing IP ID, advancing seq,
// PSH only on the last segment, and both checksums.
func TestSegmentTCPv4(t *testing.T) {
	const mss = 100
	const numSeg = 3
	pkt, hdr := buildTSOv4(t, mss*numSeg, mss)

	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(out))
	}

	for i, seg := range out {
		if len(seg) != 40+mss {
			t.Errorf("seg %d: unexpected len %d", i, len(seg))
		}
		totalLen := binary.BigEndian.Uint16(seg[2:4])
		if totalLen != uint16(40+mss) {
			t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
		}
		// TCP-style TSO increments the IPv4 ID by one per segment.
		id := binary.BigEndian.Uint16(seg[4:6])
		if id != 0x4242+uint16(i) {
			t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
		}
		// Sequence number advances by one MSS per segment.
		seq := binary.BigEndian.Uint32(seg[24:28])
		wantSeq := uint32(10000 + i*mss)
		if seq != wantSeq {
			t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
		}
		flags := seg[33]
		wantFlags := byte(0x10) // ACK only, PSH cleared
		if i == numSeg-1 {
			wantFlags = 0x18 // ACK | PSH preserved on last
		}
		if flags != wantFlags {
			t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
		}
		// IPv4 header checksum must verify against itself.
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		// TCP checksum must verify against the pseudo-header.
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
|
||||
|
||||
func TestSegmentTCPv4OddTail(t *testing.T) {
|
||||
// Payload of 250 bytes with MSS 100 → segments of 100, 100, 50.
|
||||
pkt, hdr := buildTSOv4(t, 250, 100)
|
||||
scratch := make([]byte, testSegScratchSize)
|
||||
var out [][]byte
|
||||
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentForTest: %v", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("want 3 segments, got %d", len(out))
|
||||
}
|
||||
wantPayLens := []int{100, 100, 50}
|
||||
for i, seg := range out {
|
||||
if len(seg)-40 != wantPayLens[i] {
|
||||
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i])
|
||||
}
|
||||
if !verifyChecksum(seg[:20], 0) {
|
||||
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
||||
}
|
||||
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i])
|
||||
if !verifyChecksum(seg[20:], psum) {
|
||||
t.Errorf("seg %d: bad TCP checksum", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSegmentTCPv6 segments an IPv6/TCP superpacket and checks per-segment
// payload_length, seq advance, FIN/PSH clearing on non-final segments,
// and the TCP checksum against the IPv6 pseudo-header.
func TestSegmentTCPv6(t *testing.T) {
	const ipLen = 40
	const tcpLen = 20
	const mss = 120
	const numSeg = 2
	payLen := mss * numSeg
	pkt := make([]byte, ipLen+tcpLen+payLen)

	// IPv6 header
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
	pkt[6] = unix.IPPROTO_TCP
	pkt[7] = 64
	// src/dst fe80::1 / fe80::2
	pkt[8] = 0xfe
	pkt[9] = 0x80
	pkt[23] = 1
	pkt[24] = 0xfe
	pkt[25] = 0x80
	pkt[39] = 2

	// TCP header
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 80)
	binary.BigEndian.PutUint32(pkt[44:48], 7)
	binary.BigEndian.PutUint32(pkt[48:52], 99)
	pkt[52] = 0x50
	pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
	binary.BigEndian.PutUint16(pkt[54:56], 65535)

	// Deterministic payload pattern.
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i)
	}

	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen),
		CsumOffset: 16,
	}

	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, len(out))
	}

	for i, seg := range out {
		if len(seg) != ipLen+tcpLen+mss {
			t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
		}
		// IPv6 payload_length counts TCP header + data, not the IP header.
		pl := binary.BigEndian.Uint16(seg[4:6])
		if pl != uint16(tcpLen+mss) {
			t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
		}
		seq := binary.BigEndian.Uint32(seg[44:48])
		if seq != uint32(7+i*mss) {
			t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
		}
		flags := seg[53]
		// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
		// cleared on all but the last; ACK(0x10) always preserved.
		wantFlags := byte(0x10)
		if i == numSeg-1 {
			wantFlags = 0x19
		}
		if flags != wantFlags {
			t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
		}
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
		if !verifyChecksum(seg[ipLen:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
|
||||
|
||||
func TestSegmentGSONonePassesThrough(t *testing.T) {
|
||||
pkt, hdr := buildTSOv4(t, 100, 100)
|
||||
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
|
||||
hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
|
||||
|
||||
scratch := make([]byte, testSegScratchSize)
|
||||
var out [][]byte
|
||||
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentForTest: %v", err)
|
||||
}
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("want 1 segment, got %d", len(out))
|
||||
}
|
||||
if len(out[0]) != len(pkt) {
|
||||
t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSegmentRejectsLegacyUDPGSO ensures the legacy GSO_UDP (UFO) marker is
|
||||
// still rejected; only modern GSO_UDP_L4 (USO) is supported.
|
||||
func TestSegmentRejectsLegacyUDPGSO(t *testing.T) {
|
||||
hdr := virtio.Hdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
|
||||
var out [][]byte
|
||||
if err := segmentForTest(nil, hdr, &out, nil); err == nil {
|
||||
t.Fatalf("expected rejection for legacy UDP GSO")
|
||||
}
|
||||
}
|
||||
|
||||
// buildUSOv4 builds a synthetic IPv4/UDP USO superpacket with payload of
// payLen bytes, segmented at gsoSize. Returns the packet bytes together
// with the matching virtio_net_hdr (NEEDS_CSUM + GSO_UDP_L4, CsumOffset 6
// = the UDP checksum field). Seed values the tests assert against: IP ID
// 0x4242, sport 12345, dport 53.
func buildUSOv4(t *testing.T, payLen, gsoSize int) ([]byte, virtio.Hdr) {
	t.Helper()
	const ipLen = 20
	const udpLen = 8
	pkt := make([]byte, ipLen+udpLen+payLen)

	// IPv4 header
	pkt[0] = 0x45 // version 4, IHL 5
	binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+udpLen+payLen))
	binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
	pkt[8] = 64
	pkt[9] = unix.IPPROTO_UDP
	copy(pkt[12:16], []byte{10, 0, 0, 1})
	copy(pkt[16:20], []byte{10, 0, 0, 2})

	// UDP header (length + checksum filled in per segment by segmentUDPYield)
	binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
	binary.BigEndian.PutUint16(pkt[22:24], 53)    // dport

	// Deterministic payload pattern.
	for i := 0; i < payLen; i++ {
		pkt[ipLen+udpLen+i] = byte(i & 0xff)
	}

	return pkt, virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
		HdrLen:     uint16(ipLen + udpLen),
		GSOSize:    uint16(gsoSize),
		CsumStart:  uint16(ipLen),
		CsumOffset: 6,
	}
}
|
||||
|
||||
// TestSegmentUDPv4 segments an even 3×gso IPv4/UDP (USO) superpacket and
// checks per-segment lengths, the constant IP ID, the rewritten UDP
// length field, and both checksums.
func TestSegmentUDPv4(t *testing.T) {
	const gso = 100
	const numSeg = 3
	pkt, hdr := buildUSOv4(t, gso*numSeg, gso)

	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(out))
	}

	for i, seg := range out {
		if len(seg) != 28+gso {
			t.Errorf("seg %d: len %d want %d", i, len(seg), 28+gso)
		}
		totalLen := binary.BigEndian.Uint16(seg[2:4])
		if totalLen != uint16(28+gso) {
			t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 28+gso)
		}
		// kernel UDP-GSO does NOT bump the IPv4 ID across segments; every
		// segment carries the same ID as the seed.
		id := binary.BigEndian.Uint16(seg[4:6])
		if id != 0x4242 {
			t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242)
		}
		udpLen := binary.BigEndian.Uint16(seg[24:26])
		if udpLen != uint16(8+gso) {
			t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+gso)
		}
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+gso)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad UDP checksum", i)
		}
	}
}
|
||||
|
||||
func TestSegmentUDPv4OddTail(t *testing.T) {
|
||||
// 250 bytes payload, gsoSize=100 → segments of 100, 100, 50.
|
||||
pkt, hdr := buildUSOv4(t, 250, 100)
|
||||
scratch := make([]byte, testSegScratchSize)
|
||||
var out [][]byte
|
||||
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentForTest: %v", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("want 3 segments, got %d", len(out))
|
||||
}
|
||||
wantPay := []int{100, 100, 50}
|
||||
for i, seg := range out {
|
||||
if len(seg)-28 != wantPay[i] {
|
||||
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-28, wantPay[i])
|
||||
}
|
||||
udpLen := binary.BigEndian.Uint16(seg[24:26])
|
||||
if udpLen != uint16(8+wantPay[i]) {
|
||||
t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+wantPay[i])
|
||||
}
|
||||
if !verifyChecksum(seg[:20], 0) {
|
||||
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
||||
}
|
||||
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+wantPay[i])
|
||||
if !verifyChecksum(seg[20:], psum) {
|
||||
t.Errorf("seg %d: bad UDP checksum", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSegmentUDPv6 segments an IPv6/UDP (USO) superpacket and checks
// per-segment IPv6 payload_length, the rewritten UDP length field, and
// the UDP checksum against the IPv6 pseudo-header.
func TestSegmentUDPv6(t *testing.T) {
	const ipLen = 40
	const udpLen = 8
	const gso = 120
	const numSeg = 2
	payLen := gso * numSeg
	pkt := make([]byte, ipLen+udpLen+payLen)

	// IPv6 header (src/dst fe80::1 / fe80::2)
	pkt[0] = 0x60
	binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen+payLen))
	pkt[6] = unix.IPPROTO_UDP
	pkt[7] = 64
	pkt[8] = 0xfe
	pkt[9] = 0x80
	pkt[23] = 1
	pkt[24] = 0xfe
	pkt[25] = 0x80
	pkt[39] = 2

	// UDP sport/dport; length and checksum are per-segment rewrites.
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 53)

	// Deterministic payload pattern.
	for i := 0; i < payLen; i++ {
		pkt[ipLen+udpLen+i] = byte(i)
	}

	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
		HdrLen:     uint16(ipLen + udpLen),
		GSOSize:    uint16(gso),
		CsumStart:  uint16(ipLen),
		CsumOffset: 6,
	}

	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, len(out))
	}

	for i, seg := range out {
		if len(seg) != ipLen+udpLen+gso {
			t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+udpLen+gso)
		}
		// IPv6 payload_length counts UDP header + data, not the IP header.
		pl := binary.BigEndian.Uint16(seg[4:6])
		if pl != uint16(udpLen+gso) {
			t.Errorf("seg %d: payload_length=%d want %d", i, pl, udpLen+gso)
		}
		ul := binary.BigEndian.Uint16(seg[ipLen+4 : ipLen+6])
		if ul != uint16(udpLen+gso) {
			t.Errorf("seg %d: udp len=%d want %d", i, ul, udpLen+gso)
		}
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_UDP, udpLen+gso)
		if !verifyChecksum(seg[ipLen:], psum) {
			t.Errorf("seg %d: bad UDP checksum", i)
		}
	}
}
|
||||
|
||||
// TestSegmentUDPCEPropagates confirms IP-level CE marks on the seed appear on
|
||||
// every segment. UDP has no transport-level CWR/ECE: the IP TOS/TC byte is
|
||||
// copied verbatim into every segment by the segment-prefix copy.
|
||||
func TestSegmentUDPCEPropagates(t *testing.T) {
|
||||
pkt, hdr := buildUSOv4(t, 200, 100)
|
||||
pkt[1] = 0x03 // CE codepoint in IP-ECN
|
||||
|
||||
scratch := make([]byte, testSegScratchSize)
|
||||
var out [][]byte
|
||||
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentForTest: %v", err)
|
||||
}
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("want 2 segments, got %d", len(out))
|
||||
}
|
||||
for i, seg := range out {
|
||||
if seg[1]&0x03 != 0x03 {
|
||||
t.Errorf("seg %d: CE missing (tos=%#x)", i, seg[1])
|
||||
}
|
||||
if !verifyChecksum(seg[:20], 0) {
|
||||
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSegmentTCPCwrFirstSegmentOnly confirms RFC 3168 §6.1.2: when a TSO
// burst's seed has CWR set, only the first emitted segment carries CWR.
// ECE is preserved on every segment (different signal, persistent state).
func TestSegmentTCPCwrFirstSegmentOnly(t *testing.T) {
	const mss = 100
	const numSeg = 3
	pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
	// Seed flags: CWR | ECE | ACK | PSH.
	pkt[33] = 0x80 | 0x40 | 0x10 | 0x08

	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		flags := seg[33]
		hasCwr := flags&0x80 != 0
		hasEce := flags&0x40 != 0
		hasPsh := flags&0x08 != 0
		// CWR only on the first segment; PSH only on the last.
		wantCwr := i == 0
		wantPsh := i == numSeg-1
		if hasCwr != wantCwr {
			t.Errorf("seg %d: CWR=%v want %v (flags=%#x)", i, hasCwr, wantCwr, flags)
		}
		if !hasEce {
			t.Errorf("seg %d: ECE missing (flags=%#x)", i, flags)
		}
		if hasPsh != wantPsh {
			t.Errorf("seg %d: PSH=%v want %v (flags=%#x)", i, hasPsh, wantPsh, flags)
		}
		// IP and TCP checksums must still verify after the flag rewrite.
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
|
||||
|
||||
// BenchmarkSegmentTCPv4 measures TSO segmentation throughput for several
// superpacket sizes at MSS 1460. b.SetBytes reports MB/s of superpacket
// input processed.
func BenchmarkSegmentTCPv4(b *testing.B) {
	sizes := []struct {
		name   string
		payLen int
		mss    int
	}{
		{"64KiB_MSS1460", 65000, 1460},
		{"16KiB_MSS1460", 16384, 1460},
		{"4KiB_MSS1460", 4096, 1460},
	}
	for _, sz := range sizes {
		b.Run(sz.name, func(b *testing.B) {
			// Build one seed IPv4/TCP superpacket (same layout as buildTSOv4).
			const ipLen = 20
			const tcpLen = 20
			pkt := make([]byte, ipLen+tcpLen+sz.payLen)
			pkt[0] = 0x45
			binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
			binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
			pkt[8] = 64
			pkt[9] = unix.IPPROTO_TCP
			copy(pkt[12:16], []byte{10, 0, 0, 1})
			copy(pkt[16:20], []byte{10, 0, 0, 2})
			binary.BigEndian.PutUint16(pkt[20:22], 12345)
			binary.BigEndian.PutUint16(pkt[22:24], 80)
			binary.BigEndian.PutUint32(pkt[24:28], 10000)
			binary.BigEndian.PutUint32(pkt[28:32], 20000)
			pkt[32] = 0x50
			pkt[33] = 0x18
			binary.BigEndian.PutUint16(pkt[34:36], 65535)
			for i := 0; i < sz.payLen; i++ {
				pkt[ipLen+tcpLen+i] = byte(i)
			}
			hdr := virtio.Hdr{
				Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
				HdrLen:     uint16(ipLen + tcpLen),
				GSOSize:    uint16(sz.mss),
				CsumStart:  uint16(ipLen),
				CsumOffset: 16,
			}

			scratch := make([]byte, testSegScratchSize)
			out := make([][]byte, 0, 64)

			// SegmentSuperpacket consumes its input destructively; restore
			// pkt from a master copy each iteration. The restore mirrors the
			// kernel→userspace copy that hands a fresh GSO blob to the
			// segmenter in production, so it's representative cost rather
			// than bench overhead.
			master := append([]byte(nil), pkt...)
			work := make([]byte, len(pkt))

			b.SetBytes(int64(len(pkt)))
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				copy(work, master)
				out = out[:0]
				if err := segmentForTest(work, hdr, &out, scratch); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}
|
||||
|
||||
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
|
||||
// allocation-free. We write to /dev/null so every call succeeds synchronously.
|
||||
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
|
||||
fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("open /dev/null: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = unix.Close(fd) })
|
||||
|
||||
tf := &Offload{fd: fd}
|
||||
|
||||
payload := make([]byte, 1400)
|
||||
// Warm up (first call may trigger one-time internal allocations elsewhere).
|
||||
if _, err := tf.Write(payload); err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
if _, err := tf.Write(payload); err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
})
|
||||
if allocs != 0 {
|
||||
t.Fatalf("Write allocated %.1f times per call, want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
// buildTSOv6 builds a synthetic IPv6/TCP TSO superpacket with payLen bytes
// of payload, segmented at gso. Returns the packet bytes only; the
// virtio_net_hdr is the caller's responsibility. Seed values: fe80::1 →
// fe80::2, seq 7, flags ACK.
func buildTSOv6(payLen, gso int) []byte {
	const ipLen = 40
	const tcpLen = 20
	pkt := make([]byte, ipLen+tcpLen+payLen)

	// IPv6 header (payload_length = TCP header + data)
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
	pkt[6] = unix.IPPROTO_TCP
	pkt[7] = 64
	// src/dst fe80::1 / fe80::2
	pkt[8] = 0xfe
	pkt[9] = 0x80
	pkt[23] = 1
	pkt[24] = 0xfe
	pkt[25] = 0x80
	pkt[39] = 2

	// TCP header
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 80)
	binary.BigEndian.PutUint32(pkt[44:48], 7)
	binary.BigEndian.PutUint32(pkt[48:52], 99)
	pkt[52] = 0x50
	pkt[53] = 0x10 // ACK only
	binary.BigEndian.PutUint16(pkt[54:56], 65535)

	// Deterministic payload pattern.
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i)
	}
	return pkt
}
|
||||
|
||||
// TestDecodeReadFitsMaxTSOAtDrainThreshold proves the rxBuf sizing is
// correct: when rxOff is at the maximum value the drain headroom check
// allows, decodeRead must still be able to absorb a worst-case 64KiB
// TSO superpacket without dropping the burst. With segmentation deferred
// to encrypt time, decodeRead writes only the kernel-supplied bytes into
// rxBuf, so the size requirement is just "fit one worst-case input."
//
// Regression history: in a prior layout the rx buffer doubled as the
// segmentation output, a near-threshold drain read returned "scratch too
// small", the whole 45-segment TSO burst was dropped, and the remote's TCP
// fast-retransmit collapsed cwnd. Keeping this test in the new layout
// guards against re-introducing a drain headroom shortfall.
func TestDecodeReadFitsMaxTSOAtDrainThreshold(t *testing.T) {
	// Fixed IPv6 + TCP (no options) header layout used by buildTSOv6.
	const ipv6HdrLen = 40
	const tcpHdrLen = 20
	const headerLen = ipv6HdrLen + tcpHdrLen
	// Maximum TUN read body. The tunReadBufSize cap on readv's body iovec
	// is what bounds the kernel's superpacket length.
	pktLen := tunReadBufSize
	payLen := pktLen - headerLen
	// gsoSize is chosen so the superpacket splits into ~64 segments;
	// ceiling division keeps the last (short) segment accounted for.
	const targetSegs = 64
	gsoSize := (payLen + targetSegs - 1) / targetSegs

	pkt := buildTSOv6(payLen, gsoSize)
	if len(pkt) != pktLen {
		t.Fatalf("buildTSOv6 produced %d bytes, want %d", len(pkt), pktLen)
	}

	// Bare Offload: only the fields decodeRead touches below (rxBuf,
	// rxOff, readVnetScratch, pending) are populated by this test.
	o := &Offload{
		rxBuf: make([]byte, tunRxBufCap),
	}
	// rxOff at the maximum value the drain headroom check permits before
	// it would refuse another read. Any drain-time read up to this
	// threshold MUST still process correctly.
	o.rxOff = tunRxBufCap - tunRxBufSize

	// Stage the body in rxBuf as if readv(2) just placed it there.
	copy(o.rxBuf[o.rxOff:], pkt)

	// Encode the matching virtio_net_hdr.
	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
		HdrLen:     uint16(headerLen),
		GSOSize:    uint16(gsoSize),
		CsumStart:  uint16(ipv6HdrLen),
		CsumOffset: 16, // TCP checksum field offset within the TCP header
	}
	hdr.Encode(o.readVnetScratch[:])

	startRxOff := o.rxOff
	if err := o.decodeRead(pktLen); err != nil {
		t.Fatalf("decodeRead at drain threshold returned %v — rxBuf sizing regression: "+
			"tunRxBufSize=%d must hold one worst-case input (%d)",
			err, tunRxBufSize, pktLen)
	}

	// The superpacket must be queued as a single pending entry carrying
	// GSO metadata, not pre-segmented here.
	if len(o.pending) != 1 {
		t.Fatalf("got %d packets, want 1 superpacket entry", len(o.pending))
	}
	got := o.pending[0]
	if !got.GSO.IsSuperpacket() {
		t.Fatalf("expected superpacket GSO metadata, got %+v", got.GSO)
	}
	if got.GSO.Proto != GSOProtoTCP {
		t.Errorf("GSO.Proto=%d want TCP", got.GSO.Proto)
	}
	if got.GSO.Size != uint16(gsoSize) {
		t.Errorf("GSO.Size=%d want %d", got.GSO.Size, gsoSize)
	}
	if got.GSO.HdrLen != uint16(headerLen) {
		t.Errorf("GSO.HdrLen=%d want %d", got.GSO.HdrLen, headerLen)
	}
	if got.GSO.CsumStart != uint16(ipv6HdrLen) {
		t.Errorf("GSO.CsumStart=%d want %d", got.GSO.CsumStart, ipv6HdrLen)
	}
	if len(got.Bytes) != pktLen {
		t.Errorf("len(Bytes)=%d want %d", len(got.Bytes), pktLen)
	}

	// rxOff advances exactly by the kernel-supplied body length — no
	// segmentation output to account for any more.
	if o.rxOff != startRxOff+pktLen {
		t.Errorf("rxOff=%d want %d", o.rxOff, startRxOff+pktLen)
	}
	if o.rxOff > tunRxBufCap {
		t.Fatalf("rxOff=%d overran rxBuf (cap=%d)", o.rxOff, tunRxBufCap)
	}

	// Validate that segmenting the returned superpacket reproduces the
	// expected per-segment IPv6 payload length and TCP checksum.
	wantSegs := (payLen + gsoSize - 1) / gsoSize
	gotSegs := 0
	if err := SegmentSuperpacket(got, func(seg []byte) error {
		// Count every yield, including ones that fail later checks.
		defer func() { gotSegs++ }()
		if len(seg) < headerLen+1 {
			t.Errorf("seg %d too short: %d", gotSegs, len(seg))
			return nil
		}
		if seg[0]>>4 != 6 {
			t.Errorf("seg %d: bad IP version %#x", gotSegs, seg[0])
		}
		// IPv6 payload_length covers the TCP header + this segment's data.
		segPay := len(seg) - headerLen
		gotPL := binary.BigEndian.Uint16(seg[4:6])
		if gotPL != uint16(tcpHdrLen+segPay) {
			t.Errorf("seg %d: payload_len=%d want %d", gotSegs, gotPL, tcpHdrLen+segPay)
		}
		// Recompute the TCP checksum over the pseudo-header (src, dst,
		// proto, length) plus the L4 bytes; it must verify to zero.
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpHdrLen+segPay)
		if !verifyChecksum(seg[ipv6HdrLen:], psum) {
			t.Errorf("seg %d: bad TCP checksum", gotSegs)
		}
		return nil
	}); err != nil {
		t.Fatalf("SegmentSuperpacket: %v", err)
	}
	if gotSegs != wantSegs {
		t.Fatalf("got %d segments, want %d", gotSegs, wantSegs)
	}
}
|
||||
43
overlay/tio/virtio/header_linux.go
Normal file
43
overlay/tio/virtio/header_linux.go
Normal file
@@ -0,0 +1,43 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
package virtio
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// Size is the on-wire length of struct virtio_net_hdr the kernel
// prepends/expects on a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ
// not set). It matches the legacy 10-byte layout decoded/encoded by Hdr:
// two single bytes plus four 16-bit fields.
const Size = 10
|
||||
|
||||
// Hdr is the Go view of the legacy virtio_net_hdr.
type Hdr struct {
	Flags      uint8  // VIRTIO_NET_HDR_F_* bits (e.g. NEEDS_CSUM, RSC_INFO)
	GSOType    uint8  // VIRTIO_NET_HDR_GSO_* (NONE, TCPV4, TCPV6, UDP_L4)
	HdrLen     uint16 // combined L3+L4 header length; untrusted, see CorrectHdrLen
	GSOSize    uint16 // payload bytes per segment when GSOType != NONE
	CsumStart  uint16 // offset where L4 begins (synonymous with IP header length)
	CsumOffset uint16 // offset of the checksum field, relative to CsumStart
}
|
||||
|
||||
// Decode reads a virtio_net_hdr in host byte order (TUN default; we never
|
||||
// call TUNSETVNETLE so the kernel matches our endianness).
|
||||
func (h *Hdr) Decode(b []byte) {
|
||||
h.Flags = b[0]
|
||||
h.GSOType = b[1]
|
||||
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
|
||||
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
|
||||
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
|
||||
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
|
||||
}
|
||||
|
||||
// Encode is the inverse of Decode: writes the virtio_net_hdr fields into b
|
||||
// (must be at least Size bytes). Used to emit a TSO superpacket on egress.
|
||||
func (h *Hdr) Encode(b []byte) {
|
||||
b[0] = h.Flags
|
||||
b[1] = h.GSOType
|
||||
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
|
||||
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
|
||||
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
|
||||
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
|
||||
}
|
||||
401
overlay/tio/virtio/segment_linux.go
Normal file
401
overlay/tio/virtio/segment_linux.go
Normal file
@@ -0,0 +1,401 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
// Package virtio implements the pure validation, header-correction, and
|
||||
// per-segment slicing logic for kernel-supplied TSO/USO superpackets on
|
||||
// IFF_VNET_HDR TUN devices. It is FD-free and depends only on the byte
|
||||
// layout of the virtio_net_hdr and the IP/TCP/UDP headers it describes,
|
||||
// so it can be unit-tested in isolation from the tio Queue runtime.
|
||||
package virtio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
)
|
||||
|
||||
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
const (
	ipv4HeaderMinLen = 20 // IHL=5, no options
	ipv4HeaderMaxLen = 60 // IHL=15, max options
	ipv6FixedLen     = 40 // IPv6 base header; extensions would extend this
	tcpHeaderMinLen  = 20 // data-offset=5, no options
	tcpHeaderMaxLen  = 60 // data-offset=15, max options
)

// Byte offsets inside an IPv4 header.
const (
	ipv4TotalLenOff = 2
	ipv4IDOff       = 4
	ipv4ChecksumOff = 10
	ipv4SrcOff      = 12
	ipv4AddrsEnd    = 20 // end of dst address (ipv4SrcOff + 2*4)
)

// Byte offsets inside an IPv6 header.
const (
	ipv6PayloadLenOff = 4
	ipv6SrcOff        = 8
	ipv6AddrsEnd      = 40 // end of dst address (ipv6SrcOff + 2*16)
)

// Byte offsets inside a TCP header (relative to its start, i.e. csumStart).
const (
	tcpSeqOff      = 4
	tcpDataOffOff  = 12 // upper nibble is header len in 32-bit words
	tcpFlagsOff    = 13 // low byte of the 16-bit word at offset 12
	tcpChecksumOff = 16
)

// UDP header is fixed at 8 bytes: {sport, dport, length, checksum}.
const (
	udpHeaderLen   = 8
	udpLengthOff   = 4
	udpChecksumOff = 6
)

// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)

// tcpCwrFlag is cleared on every segment except the first. Per RFC 3168
// §6.1.2 the CWR bit signals a one-shot transition (the sender just halved
// its window) and must appear on the first segment of a TSO burst only.
const tcpCwrFlag = 0x80
|
||||
|
||||
// CheckValid rejects packets whose virtio_net_hdr/IP combination would
|
||||
// cause a downstream miscompute. The TUN should never emit RSC_INFO and
|
||||
// the GSO type must agree with the IP version nibble.
|
||||
func CheckValid(pkt []byte, hdr Hdr) error {
|
||||
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
|
||||
// carry coalescing info rather than checksum offsets. A TUN writing via
|
||||
// IFF_VNET_HDR should never emit this, but if it did we would silently
|
||||
// miscompute the segment checksums — refuse the packet instead.
|
||||
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
|
||||
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
|
||||
}
|
||||
if len(pkt) < ipv4HeaderMinLen {
|
||||
return fmt.Errorf("packet too short")
|
||||
}
|
||||
ipVersion := pkt[0] >> 4
|
||||
switch hdr.GSOType {
|
||||
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
|
||||
if ipVersion != 4 {
|
||||
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
||||
}
|
||||
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
||||
if ipVersion != 6 {
|
||||
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
||||
}
|
||||
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
|
||||
// USO carries either v4 or v6; the leading nibble disambiguates.
|
||||
if !(ipVersion == 4 || ipVersion == 6) {
|
||||
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
||||
}
|
||||
default:
|
||||
if !(ipVersion == 6 || ipVersion == 4) {
|
||||
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CorrectHdrLen rewrites hdr.HdrLen based on the actual transport header
|
||||
// length read out of pkt. The kernel's hdr.HdrLen on the FORWARD path can
|
||||
// be the length of the entire first packet, so we don't trust it.
|
||||
func CorrectHdrLen(pkt []byte, hdr *Hdr) error {
|
||||
// Thank you wireguard-go for documenting these edge-cases
|
||||
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
||||
// of the entire first packet when the kernel is handling it as part of a
|
||||
// FORWARD path. Instead, parse the transport header length and add it onto
|
||||
// csumStart, which is synonymous for IP header length.
|
||||
|
||||
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||
hdr.HdrLen = hdr.CsumStart + 8
|
||||
} else {
|
||||
if len(pkt) <= int(hdr.CsumStart+tcpDataOffOff) {
|
||||
return errors.New("packet is too short")
|
||||
}
|
||||
|
||||
tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffOff] >> 4 * 4)
|
||||
if tcpHLen < 20 || tcpHLen > 60 {
|
||||
// A TCP header must be between 20 and 60 bytes in length.
|
||||
return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||
}
|
||||
hdr.HdrLen = hdr.CsumStart + tcpHLen
|
||||
}
|
||||
|
||||
if len(pkt) < int(hdr.HdrLen) {
|
||||
return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
|
||||
}
|
||||
|
||||
if hdr.HdrLen < hdr.CsumStart {
|
||||
return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
|
||||
}
|
||||
cSumAt := int(hdr.CsumStart + hdr.CsumStart)
|
||||
if cSumAt+1 >= len(pkt) {
|
||||
return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SegmentTCP walks a TSO superpacket pkt, yielding each segment as a
// slice into pkt itself. Per-segment plaintext is laid out by sliding a
// freshly-patched copy of the L3+L4 header into pkt at offset i*gsoSize,
// where it sits immediately before that segment's payload chunk in the
// original buffer. The slide is destructive: iter i's header write overwrites
// the last hdrLen bytes of seg_{i-1}'s payload, which is dead by the time
// the next iteration begins. pkt is consumed by this call and must not be
// inspected by the caller after the final yield.
//
// hdrLenU/csumStartU are assumed to have been validated by CorrectHdrLen;
// bounds inside pkt are not re-checked here.
func SegmentTCP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
	if gsoSizeU == 0 {
		return fmt.Errorf("gso_size is zero")
	}
	if csumStartU == 0 {
		return fmt.Errorf("csum_start is zero")
	}

	headerLen := int(hdrLenU)
	csumStart := int(csumStartU)
	isV4 := pkt[0]>>4 == 4

	// TCP data-offset nibble, in 32-bit words.
	tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
	payLen := len(pkt) - headerLen
	gsoSize := int(gsoSizeU)
	// Ceiling division; a payload-free packet still yields one segment.
	numSeg := (payLen + gsoSize - 1) / gsoSize
	if numSeg == 0 {
		numSeg = 1
	}

	origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4])
	origFlags := pkt[csumStart+tcpFlagsOff]

	// Precompute the one's-complement sum of the TCP header with the
	// per-segment fields (seq, flags, checksum) zeroed; each segment's
	// actual values are folded in additively inside the loop.
	var tmp [tcpHeaderMaxLen]byte
	copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
	tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0
	tmp[tcpFlagsOff] = 0
	tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0
	baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0))

	// Pseudo-header partial sum: src + dst addresses plus the protocol
	// number; the per-segment TCP length is added in the loop.
	var baseProtoSum uint32
	if isV4 {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
	} else {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
	}
	baseProtoSum += uint32(unix.IPPROTO_TCP)

	// IPv4 additionally needs a per-segment header checksum; precompute
	// the sum with total_len, ID, and checksum zeroed.
	var origIPID uint16
	var baseIPHdrSum uint32
	if isV4 {
		origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2])
		ihl := int(pkt[0]&0x0f) * 4
		if ihl < ipv4HeaderMinLen || ihl > csumStart {
			return fmt.Errorf("bad IPv4 IHL: %d", ihl)
		}
		var ipTmp [ipv4HeaderMaxLen]byte
		copy(ipTmp[:ihl], pkt[:ihl])
		ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
		ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0
		ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
		baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
	}

	for i := 0; i < numSeg; i++ {
		segStart := i * gsoSize
		segEnd := segStart + gsoSize
		if segEnd > payLen {
			segEnd = payLen
		}
		segPayLen := segEnd - segStart
		segLen := headerLen + segPayLen
		headerOff := i * gsoSize

		// Slide the header into place immediately before this segment's
		// payload. Iter 0's header is already at pkt[:headerLen]; for
		// i ≥ 1 we copy from there. The constant-byte fields of pkt[:headerLen]
		// survive iter 0's in-place patches (only seq/flags/cksum/totalLen/id
		// are touched), and iter 0's stale variable-field values are
		// overwritten by the per-segment patches below.
		if i > 0 {
			copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
		}
		seg := pkt[headerOff : headerOff+segLen]

		segSeq := origSeq + uint32(segStart)
		// CWR only on the first segment; FIN/PSH only on the last.
		segFlags := origFlags
		if i != 0 {
			segFlags &^= tcpCwrFlag
		}
		if i != numSeg-1 {
			segFlags &^= tcpFinPshMask
		}
		totalLen := segLen

		if isV4 {
			// IPv4 ID increments per segment (wrapping at 2^16).
			segID := origIPID + uint16(i)
			binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
			binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID)
			ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
			binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
		} else {
			binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
		}

		binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq)
		seg[csumStart+tcpFlagsOff] = segFlags

		tcpLen := tcpHdrLen + segPayLen
		// Payload bytes still live at their original offset in pkt. The
		// header slide above only writes into pkt[i*G : i*G+H], which is
		// the tail of seg_{i-1}'s payload (already consumed) and never
		// overlaps seg_i's own payload at pkt[H+i*G : H+(i+1)*G].
		paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
		// Fold seq (as two 16-bit words via a 32-bit add), flags (odd-offset
		// byte, so its word value equals the byte), and the pseudo-header
		// TCP length into the precomputed partial sums; a 64-bit accumulator
		// absorbs all carries before the final fold.
		wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide)))

		if err := yield(seg); err != nil {
			return err
		}
	}

	return nil
}
|
||||
|
||||
// SegmentUDP walks a USO superpacket, sliding a per-segment-patched
// L3+L4 header into pkt at offset i*gsoSize and yielding pkt[i*G:i*G+segLen]
// to the caller. Per-segment patches are total_len + IPv4 csum (or IPv6
// payload_len) plus the UDP length and checksum. pkt is consumed
// destructively; see SegmentTCP for the layout reasoning.
//
// UDP-GSO leaves the IPv4 ID identical across segments (the kernel does not
// bump it), which is why the IP-level per-segment work is limited to
// total_len + IPv4 header checksum (v4) or payload_len (v6).
func SegmentUDP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
	if gsoSizeU == 0 {
		return fmt.Errorf("gso_size is zero")
	}
	if csumStartU == 0 {
		return fmt.Errorf("csum_start is zero")
	}

	isV4 := pkt[0]>>4 == 4
	headerLen := int(hdrLenU)
	csumStart := int(csumStartU)
	// The L4 portion of the header must be exactly the fixed UDP header.
	if headerLen-csumStart != udpHeaderLen {
		return fmt.Errorf("udp header len mismatch: %d", headerLen-csumStart)
	}

	payLen := len(pkt) - headerLen
	gsoSize := int(gsoSizeU)
	// Ceiling division; a payload-free packet still yields one segment.
	numSeg := (payLen + gsoSize - 1) / gsoSize
	if numSeg == 0 {
		numSeg = 1
	}

	// Partial sum of the UDP header with length and checksum zeroed;
	// the per-segment length is folded back in inside the loop.
	var udpTmp [udpHeaderLen]byte
	copy(udpTmp[:], pkt[csumStart:headerLen])
	udpTmp[udpLengthOff], udpTmp[udpLengthOff+1] = 0, 0
	udpTmp[udpChecksumOff], udpTmp[udpChecksumOff+1] = 0, 0
	baseUDPHdrSum := uint32(checksum.Checksum(udpTmp[:], 0))

	// Pseudo-header partial sum: addresses + protocol number.
	var baseProtoSum uint32
	if isV4 {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
	} else {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
	}
	baseProtoSum += uint32(unix.IPPROTO_UDP)

	// IPv4 header sum with total_len and checksum zeroed (ID is untouched
	// across segments for UDP-GSO, so it stays in the base sum).
	var baseIPHdrSum uint32
	if isV4 {
		ihl := int(pkt[0]&0x0f) * 4
		if ihl < ipv4HeaderMinLen || ihl > csumStart {
			return fmt.Errorf("bad IPv4 IHL: %d", ihl)
		}
		var ipTmp [ipv4HeaderMaxLen]byte
		copy(ipTmp[:ihl], pkt[:ihl])
		ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
		ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
		baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
	}

	for i := 0; i < numSeg; i++ {
		segStart := i * gsoSize
		segEnd := segStart + gsoSize
		if segEnd > payLen {
			segEnd = payLen
		}
		segPayLen := segEnd - segStart
		segLen := headerLen + segPayLen
		headerOff := i * gsoSize

		// Destructive header slide — see SegmentTCP for why this never
		// clobbers the current segment's payload.
		if i > 0 {
			copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
		}
		seg := pkt[headerOff : headerOff+segLen]

		totalLen := segLen
		udpLen := udpHeaderLen + segPayLen

		if isV4 {
			binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
			ipSum := baseIPHdrSum + uint32(totalLen)
			binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
		} else {
			binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
		}

		binary.BigEndian.PutUint16(seg[csumStart+udpLengthOff:csumStart+udpLengthOff+2], uint16(udpLen))

		paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
		// udpLen is added TWICE on purpose: once for the pseudo-header's
		// length field and once for the UDP header's own length field
		// (which was zeroed when computing baseUDPHdrSum).
		wide := uint64(baseUDPHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		wide += uint64(udpLen) + uint64(udpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		csum := foldComplement(uint32(wide))
		// A transmitted UDP checksum of zero means "no checksum" (RFC 768);
		// an all-zero result is sent as 0xffff instead.
		if csum == 0 {
			csum = 0xffff
		}
		binary.BigEndian.PutUint16(seg[csumStart+udpChecksumOff:csumStart+udpChecksumOff+2], csum)

		if err := yield(seg); err != nil {
			return err
		}
	}

	return nil
}
|
||||
|
||||
// FinishChecksum computes the L4 checksum for a non-GSO packet that the kernel
|
||||
// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit
|
||||
// checksum field; we zero it, fold a full sum (the field was pre-loaded with
|
||||
// the pseudo-header partial sum by the kernel), and store the result.
|
||||
func FinishChecksum(seg []byte, hdr Hdr) error {
|
||||
cs := int(hdr.CsumStart)
|
||||
co := int(hdr.CsumOffset)
|
||||
if cs+co+2 > len(seg) {
|
||||
return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg))
|
||||
}
|
||||
// The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the
|
||||
// L4 region starting at cs, folding the prior partial in as the seed.
|
||||
partial := binary.BigEndian.Uint16(seg[cs+co : cs+co+2])
|
||||
seg[cs+co] = 0
|
||||
seg[cs+co+1] = 0
|
||||
binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], ^checksum.Checksum(seg[cs:], partial))
|
||||
return nil
|
||||
}
|
||||
|
||||
// foldComplement folds a 32-bit one's-complement partial sum to 16 bits and
// complements it, yielding the on-wire Internet checksum value.
func foldComplement(sum uint32) uint16 {
	// End-around carry until the sum fits in 16 bits; at most two folds
	// are ever needed for a 32-bit input.
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return ^uint16(sum)
}
|
||||
Reference in New Issue
Block a user