This commit is contained in:
JackDoan
2026-04-21 13:31:16 -05:00
parent bf4e37e99d
commit ad6b918e4d
28 changed files with 1039 additions and 698 deletions

View File

@@ -0,0 +1,70 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
// gsoContainer owns a set of GSO-capable tunFile queues plus the shared
// shutdown eventfd used to wake readers blocked in poll.
type gsoContainer struct {
	pq []*tunFile
	// pqi is exactly the same as pq, but stored as the interface type
	pqi []Queue
	// shutdownFd is a container-owned eventfd registered in every queue's
	// poll sets; writing to it (wakeForShutdown) wakes blocked poll calls.
	shutdownFd int
}
// NewGSOContainer builds an empty Container whose queues will share a single
// non-blocking, close-on-exec eventfd for shutdown wakeups.
func NewGSOContainer() (Container, error) {
	efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
	if err != nil {
		return nil, fmt.Errorf("failed to create eventfd: %w", err)
	}
	return &gsoContainer{
		pq:         []*tunFile{},
		pqi:        []Queue{},
		shutdownFd: efd,
	}, nil
}
// Queues returns the registered queues as the Queue interface type. The
// returned slice is the container's own backing store; callers must not
// mutate it.
func (c *gsoContainer) Queues() []Queue {
	return c.pqi
}
// Add wraps fd in a GSO-capable tunFile wired to the shared shutdown eventfd
// and registers it in both the concrete and interface-typed queue lists.
func (c *gsoContainer) Add(fd int) error {
	q, err := newTunFd(fd, c.shutdownFd)
	if err != nil {
		return err
	}
	c.pq = append(c.pq, q)
	c.pqi = append(c.pqi, q)
	return nil
}
// wakeForShutdown bumps the shared eventfd counter by one, making it readable
// so every queue blocked in poll on it wakes and observes shutdown.
func (c *gsoContainer) wakeForShutdown() error {
	var one [8]byte
	binary.NativeEndian.PutUint64(one[:], 1)
	_, err := unix.Write(c.shutdownFd, one[:])
	return err
}
// Close wakes all readers blocked in poll, closes every queue, and finally
// closes the container-owned shutdown eventfd. Errors from every step are
// collected and joined so each fd gets a close attempt.
func (c *gsoContainer) Close() error {
	var errs []error
	// Signal all readers blocked in poll to wake up and exit
	if err := c.wakeForShutdown(); err != nil {
		errs = append(errs, err)
	}
	for _, x := range c.pq {
		if err := x.Close(); err != nil {
			errs = append(errs, err)
		}
	}
	// The shutdown eventfd is owned by the container (queue Close leaves it
	// alone, per the note in tunFile.Close), so it must be released here or
	// it leaks one fd per container.
	if c.shutdownFd >= 0 {
		if err := unix.Close(c.shutdownFd); err != nil {
			errs = append(errs, err)
		}
		c.shutdownFd = -1
	}
	return errors.Join(errs...)
}

View File

@@ -0,0 +1,69 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
// pollContainer owns a set of poll(2)-driven Queues plus the shared shutdown
// eventfd used to wake readers blocked in poll.
type pollContainer struct {
	pq []*Poll
	// pqi is exactly the same as pq, but stored as the interface type
	pqi []Queue
	// shutdownFd is a container-owned eventfd registered in every queue's
	// poll sets; writing to it (wakeForShutdown) wakes blocked poll calls.
	shutdownFd int
}
// NewPollContainer builds an empty Container whose queues will share a single
// non-blocking, close-on-exec eventfd for shutdown wakeups.
func NewPollContainer() (Container, error) {
	efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
	if err != nil {
		return nil, fmt.Errorf("failed to create eventfd: %w", err)
	}
	return &pollContainer{
		pq:         []*Poll{},
		pqi:        []Queue{},
		shutdownFd: efd,
	}, nil
}
// Queues returns the registered queues as the Queue interface type. The
// returned slice is the container's own backing store; callers must not
// mutate it.
func (c *pollContainer) Queues() []Queue {
	return c.pqi
}
// Add wraps fd in a Poll queue wired to the shared shutdown eventfd and
// registers it in both the concrete and interface-typed queue lists.
func (c *pollContainer) Add(fd int) error {
	q, err := newPoll(fd, c.shutdownFd)
	if err != nil {
		return err
	}
	c.pq = append(c.pq, q)
	c.pqi = append(c.pqi, q)
	return nil
}
// wakeForShutdown bumps the shared eventfd counter by one, making it readable
// so every queue blocked in poll on it wakes and observes shutdown.
func (c *pollContainer) wakeForShutdown() error {
	var one [8]byte
	binary.NativeEndian.PutUint64(one[:], 1)
	_, err := unix.Write(c.shutdownFd, one[:])
	return err
}
// Close wakes all readers blocked in poll, closes every queue, and finally
// closes the container-owned shutdown eventfd. Errors from every step are
// collected and joined so each fd gets a close attempt.
func (c *pollContainer) Close() error {
	var errs []error
	// Signal all readers blocked in poll to wake up and exit.
	if err := c.wakeForShutdown(); err != nil {
		errs = append(errs, err)
	}
	for _, x := range c.pq {
		if err := x.Close(); err != nil {
			errs = append(errs, err)
		}
	}
	// The shutdown eventfd is owned by the container (queue Close leaves it
	// alone, per the note in Poll.Close), so it must be released here or it
	// leaks one fd per container.
	if c.shutdownFd >= 0 {
		if err := unix.Close(c.shutdownFd); err != nil {
			errs = append(errs, err)
		}
		c.shutdownFd = -1
	}
	return errors.Join(errs...)
}

63
overlay/tio/tio.go Normal file
View File

@@ -0,0 +1,63 @@
package tio
import "io"
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
// NOTE(review): Poll.Read's lazy-init fallback uses this while newPoll
// allocates tunReadBufSize — the two happen to be equal today; confirm they
// are meant to stay in sync.
const defaultBatchBufSize = 65535
// Container owns a set of Queues plus shared shutdown plumbing. Add registers
// a new fd-backed Queue; Close wakes every blocked reader and closes all
// queues.
type Container interface {
	Queues() []Queue
	Add(fd int) error
	io.Closer
}
// Queue is a readable/writable Poll queue. One Queue is driven by a single
// read goroutine plus concurrent writers (see Write / WriteReject below).
type Queue interface {
	// Close releases the queue's fd. Implementations in this package are
	// idempotent (guarded by an atomic flag).
	io.Closer
	// Read returns one or more packets. The returned slices are borrowed
	// from the Queue's internal buffer and are only valid until the next
	// Read or Close on this Queue — callers must encrypt or copy each
	// slice before the next call. Not safe for concurrent Reads; exactly
	// one goroutine per Queue reads.
	Read() ([][]byte, error)
	// Write emits a single packet on the plaintext (outside→inside)
	// delivery path. May run concurrently with WriteReject on the same
	// Queue, but not with itself.
	Write(p []byte) (int, error)
	// WriteReject writes a single packet that originated from the inside
	// path (reject replies or self-forward) using scratch state distinct
	// from Write, so it can run concurrently with Write on the same Queue
	// without a data race. On backends without a shared-scratch Write, a
	// trivial delegation to Write is acceptable.
	WriteReject(p []byte) (int, error)
}
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
// assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO
// support return false from GSOSupported and coalescing is skipped.
//
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable — callers will
// have filled in total length and pseudo-header partial). pays are
// non-overlapping payload fragments whose concatenation is the full
// superpacket payload; they are read-only from the writer's perspective
// and must remain valid until the call returns. gsoSize is the MSS:
// every segment except possibly the last is exactly that many bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
//
// TODO: fold into Queue.
//
// hdr's TCP checksum field must already hold the pseudo-header partial
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
type GSOWriter interface {
	WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
	GSOSupported() bool
}

View File

@@ -0,0 +1,405 @@
package tio
import (
"fmt"
"io"
"os"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// Space for segmented output. Worst case is many small segments, each paying
// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling.
const tunSegBufSize = 131072

// tunSegBufCap is the total size we allocate for the per-reader segment
// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus
// the same again as drain headroom so a Read wake can accumulate
// additional packets after an initial big read without overflowing.
// (tunFile.Read's drain loop stops once fewer than tunSegBufSize bytes of
// this capacity remain.)
const tunSegBufCap = tunSegBufSize * 2

// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64

// gsoInitialPayIovs is the starting capacity (in payload fragments) of
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without
// any reallocations.
const gsoInitialPayIovs = 66

// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID (all other bytes are
// zero), which marks the skb CHECKSUM_UNNECESSARY so the receiving network
// stack skips L4 checksum verification. All packets that reach the plain
// Write / WriteReject paths already carry a valid L4 checksum (either
// supplied by a remote peer whose ciphertext we AEAD-authenticated, or
// produced by finishChecksum during TSO segmentation, or built locally by
// CreateRejectPacket), so trusting them is safe.
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct { //todo rename GSO
	fd         int // tun fd; set to -1 by Close
	shutdownFd int // container-owned eventfd; registered in both poll sets, never closed here
	readPoll   [2]unix.PollFd // slot 0: tun POLLIN, slot 1: shutdown eventfd
	writePoll  [2]unix.PollFd // slot 0: tun POLLOUT, slot 1: shutdown eventfd
	closed     atomic.Bool    // set once by Close; makes Close idempotent
	readBuf    []byte // scratch for a single raw read (virtio hdr + superpacket)
	segBuf     []byte // backing store for segmented output
	segOff     int    // cursor into segBuf for the current Read drain
	pending    [][]byte // segments returned from the most recent Read
	writeIovs  [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr
	// rejectIovs is a second preallocated iovec scratch used exclusively by
	// WriteReject (reject + self-forward from the inside path). It mirrors
	// writeIovs but lets listenIn goroutines emit reject packets without
	// racing with the listenOut coalescer that owns writeIovs.
	rejectIovs [2]unix.Iovec
	// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
	// by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on
	// another queue never observes a half-written header.
	gsoHdrBuf [virtioNetHdrLen]byte
	// gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the
	// virtio header + IP/TCP header + up to gsoInitialPayIovs payload
	// fragments; grown on demand if a coalescer pushes more.
	gsoIovs []unix.Iovec
}
// newTunFd wraps an existing TUN fd in a *tunFile, flipping it to
// non-blocking mode and wiring both poll sets to the container-owned
// shutdown eventfd (which is never closed here).
//
// NOTE(review): unlike newPoll, this does not close fd when SetNonblock
// fails — confirm which side owns the fd on error.
func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
	if err := unix.SetNonblock(fd, true); err != nil {
		return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
	}
	out := &tunFile{
		fd:         fd,
		shutdownFd: shutdownFd,
		readBuf:    make([]byte, tunReadBufSize),
		readPoll: [2]unix.PollFd{
			{Fd: int32(fd), Events: unix.POLLIN},
			{Fd: int32(shutdownFd), Events: unix.POLLIN},
		},
		writePoll: [2]unix.PollFd{
			{Fd: int32(fd), Events: unix.POLLOUT},
			{Fd: int32(shutdownFd), Events: unix.POLLIN},
		},
		// Allocate the full tunSegBufCap (superpacket + drain headroom), not
		// just tunSegBufSize: Read's drain loop gates on tunSegBufCap-segOff,
		// so a smaller buffer made the advertised headroom nonexistent and
		// every drained packet past the first superpacket failed to decode.
		segBuf:  make([]byte, tunSegBufCap),
		gsoIovs: make([]unix.Iovec, 2, 2+gsoInitialPayIovs),
	}
	// Pre-wire the fixed virtio-header iovec for both write paths so the hot
	// paths only ever touch the payload slot.
	out.writeIovs[0].Base = &validVnetHdr[0]
	out.writeIovs[0].SetLen(virtioNetHdrLen)
	out.rejectIovs[0].Base = &validVnetHdr[0]
	out.rejectIovs[0].SetLen(virtioNetHdrLen)
	out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
	out.gsoIovs[0].SetLen(virtioNetHdrLen)
	return out, nil
}
// blockOnRead parks in poll(2) until the tun fd is readable or the shutdown
// eventfd fires, retrying poll on EINTR. Returns os.ErrClosed on shutdown or
// when either fd reports an error condition.
func (r *tunFile) blockOnRead() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var pollErr error
	for {
		if _, pollErr = unix.Poll(r.readPoll[:], -1); pollErr != unix.EINTR {
			break
		}
	}
	// Snapshot and clear revents unconditionally so stale bits never leak
	// into the next poll round.
	tunBits, downBits := r.readPoll[0].Revents, r.readPoll[1].Revents
	r.readPoll[0].Revents, r.readPoll[1].Revents = 0, 0
	// Check the poll error before trusting the potentially bogus bits.
	if pollErr != nil {
		return pollErr
	}
	if downBits&(unix.POLLIN|problemFlags) != 0 || tunBits&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// blockOnWrite parks in poll(2) until the tun fd is writable or the shutdown
// eventfd fires, retrying poll on EINTR. Returns os.ErrClosed on shutdown or
// when either fd reports an error condition.
func (r *tunFile) blockOnWrite() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var pollErr error
	for {
		if _, pollErr = unix.Poll(r.writePoll[:], -1); pollErr != unix.EINTR {
			break
		}
	}
	// Snapshot and clear revents unconditionally so stale bits never leak
	// into the next poll round.
	tunBits, downBits := r.writePoll[0].Revents, r.writePoll[1].Revents
	r.writePoll[0].Revents, r.writePoll[1].Revents = 0, 0
	// Check the poll error before trusting the potentially bogus bits.
	if pollErr != nil {
		return pollErr
	}
	if downBits&(unix.POLLIN|problemFlags) != 0 || tunBits&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// readRaw reads one raw superpacket (virtio hdr included) into buf. Parks in
// poll on EAGAIN, retries on EINTR, and maps EBADF to os.ErrClosed.
func (r *tunFile) readRaw(buf []byte) (int, error) {
	for {
		n, err := unix.Read(r.fd, buf)
		switch err {
		case nil:
			return n, nil
		case unix.EAGAIN:
			if err = r.blockOnRead(); err != nil {
				return 0, err
			}
		case unix.EINTR:
			// retry
		case unix.EBADF:
			return 0, os.ErrClosed
		default:
			return 0, err
		}
	}
}
// Read reads one or more superpackets from the tun and returns the
// resulting packets. The first read blocks via poll; once the fd is known
// readable we drain additional packets non-blocking until the kernel queue
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
// segBuf headroom. This amortizes the poll wake over bursts of small
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers
// and are only valid until the next Read or Close on this Queue.
func (r *tunFile) Read() ([][]byte, error) {
	// Reset per-call state; pending keeps its capacity across calls.
	r.pending = r.pending[:0]
	r.segOff = 0
	// Initial (blocking) read. Retry on decode errors so a single bad
	// packet does not stall the reader.
	for {
		n, err := r.readRaw(r.readBuf)
		if err != nil {
			return nil, err
		}
		if err := r.decodeRead(n); err != nil {
			// Drop and read again — a bad packet should not kill the reader.
			continue
		}
		break
	}
	// Drain: non-blocking reads until the kernel queue is empty, the drain
	// cap is reached, or segBuf no longer has room for another worst-case
	// superpacket.
	for len(r.pending) < tunDrainCap && tunSegBufCap-r.segOff >= tunSegBufSize {
		n, err := unix.Read(r.fd, r.readBuf)
		if err != nil {
			// EAGAIN / EINTR / anything else: stop draining. We already
			// have a valid batch from the first read.
			break
		}
		if n <= 0 {
			break
		}
		if err := r.decodeRead(n); err != nil {
			// Drop this packet and stop the drain; we'd rather hand off
			// what we have than keep spinning here.
			break
		}
	}
	return r.pending, nil
}
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
// the segments to r.pending, and advances r.segOff by the total scratch used.
// Every read on this type carries a virtio_net_hdr, so a read shorter than
// that header is malformed.
func (r *tunFile) decodeRead(n int) error {
	if n < virtioNetHdrLen {
		return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
	}
	var hdr VirtioNetHdr
	hdr.decode(r.readBuf[:virtioNetHdrLen])
	// Remember where this superpacket's segments start so scratch usage can
	// be accounted after segmentation.
	before := len(r.pending)
	if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf[r.segOff:]); err != nil {
		return err
	}
	for k := before; k < len(r.pending); k++ {
		r.segOff += len(r.pending[k])
	}
	return nil
}
// Write emits a single packet on the plaintext delivery path, prefixed with
// the fixed DATA_VALID virtio header via the writeIovs scratch. Not safe to
// call concurrently with itself; safe alongside WriteReject, which uses a
// separate scratch.
func (r *tunFile) Write(buf []byte) (int, error) {
	return r.writeWithScratch(buf, &r.writeIovs)
}
// WriteReject emits a packet using a dedicated iovec scratch (rejectIovs)
// distinct from the one used by the coalescer's Write path. This avoids a
// data race between the inside (listenIn) goroutine emitting reject or
// self-forward packets and the outside (listenOut) goroutine flushing TCP
// coalescer passthroughs on the same tunFile. Returns the payload byte
// count, excluding the virtio header.
func (r *tunFile) WriteReject(buf []byte) (int, error) {
	return r.writeWithScratch(buf, &r.rejectIovs)
}
// writeWithScratch prepends the fixed DATA_VALID virtio header (slot 0 of the
// caller-chosen two-slot iovec scratch) to buf and emits it in one writev.
// Returns the payload byte count (writev result minus the virtio header
// length). Each scratch (writeIovs / rejectIovs) must only be used by one
// goroutine at a time.
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
	if len(buf) == 0 {
		return 0, nil
	}
	// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
	// to validVnetHdr during tunFile construction so we don't rebuild it here.
	iovs[1].Base = &buf[0]
	iovs[1].SetLen(len(buf))
	iovPtr := uintptr(unsafe.Pointer(&iovs[0]))
	// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev
	// either completes promptly or returns EAGAIN — it cannot park the
	// goroutine inside the kernel. That lets us use syscall.RawSyscall and
	// skip the runtime.entersyscall / exitsyscall bookkeeping on every
	// packet; we only pay that cost when we fall through to blockOnWrite.
	for {
		n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
		if errno == 0 {
			// KeepAlive pins buf — referenced only through the raw iovec
			// pointer during the syscall — until the kernel is done with it.
			runtime.KeepAlive(buf)
			if int(n) < virtioNetHdrLen {
				return 0, io.ErrShortWrite
			}
			return int(n) - virtioNetHdrLen, nil
		}
		if errno == unix.EAGAIN {
			runtime.KeepAlive(buf)
			if err := r.blockOnWrite(); err != nil {
				return 0, err
			}
			continue
		}
		if errno == unix.EINTR {
			continue
		}
		if errno == unix.EBADF {
			return 0, os.ErrClosed
		}
		runtime.KeepAlive(buf)
		return 0, errno
	}
}
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
// can accept WriteGSO. When false, callers should fall back to per-segment
// Write calls.
// NOTE(review): this implementation always returns true — presumably tunFile
// is only ever constructed on the vnet-hdr path; confirm no caller reaches
// it otherwise.
func (r *tunFile) GSOSupported() bool { return true }
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
// and TCP pseudo-header partial set by the caller). pays are payload
// fragments whose concatenation forms the full coalesced payload; each
// slice is read-only and must stay valid until return. gsoSize is the MSS;
// every segment except possibly the last is exactly gsoSize bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
	if len(hdr) == 0 || len(pays) == 0 {
		return nil
	}
	// Build the virtio_net_hdr. When pays total to <= gsoSize the kernel
	// would produce a single segment; keep NEEDS_CSUM semantics but skip
	// the GSO type so the kernel doesn't spuriously mark this as TSO.
	vhdr := VirtioNetHdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		HdrLen:     uint16(len(hdr)),
		GSOSize:    gsoSize,
		CsumStart:  csumStart,
		CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header
	}
	var totalPay int
	for _, p := range pays {
		totalPay += len(p)
	}
	if totalPay > int(gsoSize) {
		if isV6 {
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
		} else {
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
		}
	} else {
		vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
		vhdr.GSOSize = 0
	}
	vhdr.encode(r.gsoHdrBuf[:])
	// Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is
	// wired to gsoHdrBuf at construction and never changes, so the grow
	// path below copies slot 0 forward.
	need := 2 + len(pays)
	if cap(r.gsoIovs) < need {
		grown := make([]unix.Iovec, need)
		grown[0] = r.gsoIovs[0]
		r.gsoIovs = grown
	} else {
		r.gsoIovs = r.gsoIovs[:need]
	}
	r.gsoIovs[1].Base = &hdr[0]
	r.gsoIovs[1].SetLen(len(hdr))
	for i, p := range pays {
		r.gsoIovs[2+i].Base = &p[0]
		r.gsoIovs[2+i].SetLen(len(p))
	}
	iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
	iovCnt := uintptr(len(r.gsoIovs))
	// Non-blocking fd + RawSyscall: same rationale as writeWithScratch.
	// KeepAlive pins hdr/pays, which are referenced only via raw iovec
	// pointers during the syscall.
	for {
		n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
		if errno == 0 {
			runtime.KeepAlive(hdr)
			runtime.KeepAlive(pays)
			if int(n) < virtioNetHdrLen {
				return io.ErrShortWrite
			}
			return nil
		}
		if errno == unix.EAGAIN {
			runtime.KeepAlive(hdr)
			runtime.KeepAlive(pays)
			if err := r.blockOnWrite(); err != nil {
				return err
			}
			continue
		}
		if errno == unix.EINTR {
			continue
		}
		if errno == unix.EBADF {
			return os.ErrClosed
		}
		runtime.KeepAlive(hdr)
		runtime.KeepAlive(pays)
		return errno
	}
}
// Close closes the tun fd exactly once; later calls are no-ops. The shared
// shutdown eventfd belongs to the container and is left open.
func (r *tunFile) Close() error {
	if r.closed.Swap(true) {
		// Already closed by an earlier call.
		return nil
	}
	if r.fd < 0 {
		return nil
	}
	fd := r.fd
	r.fd = -1
	return unix.Close(fd)
}

View File

@@ -0,0 +1,205 @@
package tio
import (
	"fmt"
	"io"
	"os"
	"sync/atomic"
	"syscall"
	"unsafe"

	"golang.org/x/sys/unix"
)
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
// prefix plus the virtio header.
// NOTE(review): 65535 does not actually leave room for the virtio header on
// top of a maximum-size packet — confirm the intended ceiling.
const tunReadBufSize = 65535

// Poll is a poll(2)-driven Queue over a tun-style fd that prefixes each
// packet with a 4-byte protocol-family header. Slot 1 of each poll set holds
// the container-owned shutdown eventfd so Close can wake blocked callers.
type Poll struct {
	fd        int
	readPoll  [2]unix.PollFd // slot 0: fd POLLIN, slot 1: shutdown eventfd
	writePoll [2]unix.PollFd // slot 0: fd POLLOUT, slot 1: shutdown eventfd
	closed    atomic.Bool // set once by Close; makes Close idempotent
	readBuf   []byte      // scratch for a single packet read
	batchRet  [1][]byte   // reused one-element batch returned by Read
}
// newPoll wraps fd (switched to non-blocking) in a *Poll wired to the
// container-owned shutdown eventfd. On failure the fd is closed.
func newPoll(fd int, shutdownFd int) (*Poll, error) {
	if err := unix.SetNonblock(fd, true); err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err)
	}
	p := &Poll{
		fd:      fd,
		readBuf: make([]byte, tunReadBufSize),
	}
	p.readPoll = [2]unix.PollFd{
		{Fd: int32(fd), Events: unix.POLLIN},
		{Fd: int32(shutdownFd), Events: unix.POLLIN},
	}
	p.writePoll = [2]unix.PollFd{
		{Fd: int32(fd), Events: unix.POLLOUT},
		{Fd: int32(shutdownFd), Events: unix.POLLIN},
	}
	return p, nil
}
// blockOnRead waits until the Poll fd is readable or shutdown has been
// signaled, retrying poll on EINTR. Returns os.ErrClosed if Close was called
// or either fd reports an error condition.
func (t *Poll) blockOnRead() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var pollErr error
	for {
		if _, pollErr = unix.Poll(t.readPoll[:], -1); pollErr != unix.EINTR {
			break
		}
	}
	// Snapshot and clear revents before any early return.
	tunBits, downBits := t.readPoll[0].Revents, t.readPoll[1].Revents
	t.readPoll[0].Revents, t.readPoll[1].Revents = 0, 0
	if pollErr != nil {
		return pollErr
	}
	if downBits&(unix.POLLIN|problemFlags) != 0 || tunBits&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// blockOnWrite waits until the Poll fd is writable or shutdown has been
// signaled, retrying poll on EINTR. Returns os.ErrClosed if Close was called
// or either fd reports an error condition.
func (t *Poll) blockOnWrite() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var pollErr error
	for {
		if _, pollErr = unix.Poll(t.writePoll[:], -1); pollErr != unix.EINTR {
			break
		}
	}
	// Snapshot and clear revents before any early return.
	tunBits, downBits := t.writePoll[0].Revents, t.writePoll[1].Revents
	t.writePoll[0].Revents, t.writePoll[1].Revents = 0, 0
	if pollErr != nil {
		return pollErr
	}
	if downBits&(unix.POLLIN|problemFlags) != 0 || tunBits&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// Read returns a single packet as a one-element batch. The returned slice
// aliases the Poll's internal readBuf and is only valid until the next Read
// or Close on this Queue.
func (t *Poll) Read() ([][]byte, error) {
	// Defensive lazy init for a zero-value Poll; newPoll always allocates
	// readBuf (at tunReadBufSize), so this branch is normally dead.
	// NOTE(review): this fallback uses defaultBatchBufSize while newPoll
	// uses tunReadBufSize — confirm which constant is intended.
	if t.readBuf == nil {
		t.readBuf = make([]byte, defaultBatchBufSize)
	}
	n, err := t.readOne(t.readBuf)
	if err != nil {
		return nil, err
	}
	t.batchRet[0] = t.readBuf[:n]
	return t.batchRet[:], nil
}
// readOne performs one readv(2) into `to`, using a leading 4-byte iovec to
// strip the protocol-family prefix this device prepends. Parks in poll on
// EAGAIN, retries on EINTR, and maps EBADF to os.ErrClosed. Returns the
// payload length; a read shorter than the 4-byte prefix yields (0, nil).
func (t *Poll) readOne(to []byte) (int, error) {
	// first 4 bytes is protocol family, in network byte order
	var head [4]byte
	iovecs := [2]syscall.Iovec{ //todo plat-specific
		{&head[0], 4},
		{&to[0], uint64(len(to))},
	}
	for {
		n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
		if errno == 0 {
			bytesRead := int(n)
			if bytesRead < 4 {
				// Short read: nothing beyond the family header. Report an
				// empty packet rather than an error.
				return 0, nil
			}
			return bytesRead - 4, nil
		}
		switch errno {
		case unix.EAGAIN:
			if err := t.blockOnRead(); err != nil {
				return 0, err
			}
		case unix.EINTR:
			// retry
		case unix.EBADF:
			return 0, os.ErrClosed
		default:
			return 0, errno
		}
	}
}
// Write emits a single IP packet, prepending the 4-byte protocol-family
// header (network byte order) derived from the packet's IP version nibble.
// Returns the payload byte count, excluding the family header.
// Write is only valid for single threaded use.
func (t *Poll) Write(from []byte) (int, error) {
	if len(from) <= 1 {
		return 0, syscall.EIO
	}
	ipVer := from[0] >> 4
	var head [4]byte
	// first 4 bytes is protocol family, in network byte order
	switch ipVer {
	case 4:
		head[3] = syscall.AF_INET
	case 6:
		head[3] = syscall.AF_INET6
	default:
		return 0, fmt.Errorf("unable to determine IP version from packet")
	}
	iovecs := [2]syscall.Iovec{ //todo plat specific
		{&head[0], 4},
		{&from[0], uint64(len(from))},
	}
	for {
		n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
		if errno == 0 {
			// Guard against a short write of the 4-byte family header so we
			// never report a negative byte count to the caller.
			if int(n) < 4 {
				return 0, io.ErrShortWrite
			}
			return int(n) - 4, nil
		}
		switch errno {
		case unix.EAGAIN:
			if err := t.blockOnWrite(); err != nil {
				return 0, err
			}
		case unix.EINTR:
			// retry
		case unix.EBADF:
			return 0, os.ErrClosed
		default:
			return 0, errno
		}
	}
}
// Close closes the Poll fd exactly once; subsequent calls are no-ops. The
// shared shutdown eventfd is owned by the container and is left open.
func (t *Poll) Close() error {
	if t.closed.Swap(true) {
		// Already closed by an earlier call.
		return nil
	}
	if t.fd < 0 {
		return nil
	}
	fd := t.fd
	t.fd = -1
	return unix.Close(fd)
}
// WriteReject satisfies the Queue interface by delegating to Write: this
// backend keeps no shared write scratch, so the Queue contract permits a
// trivial delegation. Note Write's single-threaded-use caveat still applies.
func (t *Poll) WriteReject(p []byte) (int, error) {
	return t.Write(p)
}

View File

@@ -0,0 +1,120 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"errors"
"os"
"sync"
"testing"
"time"
"golang.org/x/sys/unix"
)
// newReadPipe creates a pipe and returns its read end; the write end is
// closed automatically at test cleanup. Ownership of the read fd transfers
// to the caller (pass it to newTunFd / newFriend).
func newReadPipe(t *testing.T) int {
	t.Helper()
	var fds [2]int
	if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
		t.Fatalf("pipe2: %v", err)
	}
	writeEnd := fds[1]
	t.Cleanup(func() { _ = unix.Close(writeEnd) })
	return fds[0]
}
// TestTunFile_WakeForShutdown_UnblocksRead verifies that signaling shutdown
// wakes a Read parked in poll and that the reader surfaces os.ErrClosed.
//
// NOTE(review): this test calls newTunFd with one argument, Read with a
// buffer, and wakeForShutdown on the tunFile itself — a different API shape
// than tun_linux.go in this commit (newTunFd(fd, shutdownFd), Read() with no
// args, wakeForShutdown on the container). Confirm which revision these
// tests target.
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
	tf, err := newTunFd(newReadPipe(t))
	if err != nil {
		t.Fatalf("newTunFd: %v", err)
	}
	t.Cleanup(func() { _ = tf.Close() })
	done := make(chan error, 1)
	go func() {
		_, err := tf.Read(make([]byte, 64))
		done <- err
	}()
	// Verify Read is actually blocked in poll.
	select {
	case err := <-done:
		t.Fatalf("Read returned before shutdown signal: %v", err)
	case <-time.After(50 * time.Millisecond):
	}
	if err := tf.wakeForShutdown(); err != nil {
		t.Fatalf("wakeForShutdown: %v", err)
	}
	select {
	case err := <-done:
		if !errors.Is(err, os.ErrClosed) {
			t.Fatalf("expected os.ErrClosed, got %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Read did not wake on shutdown")
	}
}
// TestTunFile_WakeForShutdown_WakesFriends verifies that one shutdown signal
// wakes every reader sharing the parent's shutdown plumbing (the parent and
// a "friend" created from it), and that each surfaces os.ErrClosed.
//
// NOTE(review): relies on newFriend and the one-arg newTunFd — an API shape
// not present in tun_linux.go in this commit; confirm the intended revision.
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
	parent, err := newTunFd(newReadPipe(t))
	if err != nil {
		t.Fatalf("newTunFd: %v", err)
	}
	friend, err := parent.newFriend(newReadPipe(t))
	if err != nil {
		_ = parent.Close()
		t.Fatalf("newFriend: %v", err)
	}
	t.Cleanup(func() {
		_ = friend.Close()
		_ = parent.Close()
	})
	readers := []*tunFile{parent, friend}
	errs := make([]error, len(readers))
	var wg sync.WaitGroup
	for i, r := range readers {
		wg.Add(1)
		go func(i int, r *tunFile) {
			defer wg.Done()
			_, errs[i] = r.Read(make([]byte, 64))
		}(i, r)
	}
	// Give both readers time to park in poll before signaling.
	time.Sleep(50 * time.Millisecond)
	if err := parent.wakeForShutdown(); err != nil {
		t.Fatalf("wakeForShutdown: %v", err)
	}
	done := make(chan struct{})
	go func() { wg.Wait(); close(done) }()
	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("readers did not wake")
	}
	for i, err := range errs {
		if !errors.Is(err, os.ErrClosed) {
			t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
		}
	}
}
// TestTunFile_Close_Idempotent verifies that a second Close is a no-op
// (returns nil) rather than double-closing the fd.
//
// NOTE(review): uses the one-arg newTunFd — see the API-shape note on the
// tests above.
func TestTunFile_Close_Idempotent(t *testing.T) {
	tf, err := newTunFd(newReadPipe(t))
	if err != nil {
		t.Fatalf("newTunFd: %v", err)
	}
	if err := tf.Close(); err != nil {
		t.Fatalf("first Close: %v", err)
	}
	if err := tf.Close(); err != nil {
		t.Fatalf("second Close should be a no-op, got %v", err)
	}
}

View File

@@ -0,0 +1,275 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"encoding/binary"
"fmt"
"golang.org/x/sys/unix"
)
// segmentInto splits a TUN-side packet described by hdr into one or more IP
// packets, each appended to *out as a slice of scratch. scratch must be big
// enough to hold every segment, including the replicated headers.
func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
	// With RSC_INFO set, csum_start/csum_offset carry coalescing info rather
	// than checksum offsets. A TUN writing via IFF_VNET_HDR should never emit
	// it, but if it did we would silently miscompute the segment checksums —
	// refuse the packet instead.
	if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
		return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
	}
	switch hdr.GSOType {
	case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
		return segmentTCP(pkt, hdr, out, scratch)
	case unix.VIRTIO_NET_HDR_GSO_NONE:
		if len(pkt) > len(scratch) {
			return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
		}
		seg := scratch[:copy(scratch, pkt)]
		if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
			if err := finishChecksum(seg, hdr); err != nil {
				return err
			}
		}
		*out = append(*out, seg)
		return nil
	default:
		return fmt.Errorf("unsupported virtio gso type: %d", hdr.GSOType)
	}
}
// finishChecksum completes the L4 checksum of a non-GSO packet handed to us
// with NEEDS_CSUM set: the kernel pre-loads the 16-bit field at
// csum_start+csum_offset with the pseudo-header partial sum, so we seed with
// that value, zero the field, sum the L4 region, and fold the result back in.
func finishChecksum(seg []byte, hdr VirtioNetHdr) error {
	start := int(hdr.CsumStart)
	off := int(hdr.CsumOffset)
	field := start + off
	if field+2 > len(seg) {
		return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", start, off, len(seg))
	}
	// Seed with the kernel-provided pseudo-header partial before zeroing the
	// checksum field itself.
	seed := uint32(binary.BigEndian.Uint16(seg[field : field+2]))
	seg[field], seg[field+1] = 0, 0
	binary.BigEndian.PutUint16(seg[field:field+2], checksumFold(checksumBytes(seg[start:], seed)))
	return nil
}
// segmentTCP software-segments a TSO superpacket into one IP packet per MSS
// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6.
//
// Hot-path shape: the per-segment loop only sums the payload chunk. The TCP
// header, the IPv4 header, and the pseudo-header src/dst/proto contributions
// are each summed once up front — every segment reuses those three pre-folded
// uint32 values and combines them with small per-segment deltas (seq, flags,
// tcpLen, ip_id, total_len) that are cheap to fold in.
func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
	if hdr.GSOSize == 0 {
		return fmt.Errorf("gso_size is zero")
	}
	if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 {
		return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt))
	}
	if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen {
		return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen)
	}
	isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
	headerLen := int(hdr.HdrLen)
	csumStart := int(hdr.CsumStart)
	// csum_start must leave room for at least the fixed IP header (20 for
	// v4, 40 for v6) ahead of the TCP header.
	if isV4 && csumStart < 20 {
		return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
	}
	if !isV4 && csumStart < 40 {
		return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
	}
	tcpHdrLen := headerLen - csumStart
	if tcpHdrLen < 20 {
		return fmt.Errorf("tcp header region too small: %d", tcpHdrLen)
	}
	payload := pkt[headerLen:]
	payLen := len(payload)
	gso := int(hdr.GSOSize)
	numSeg := (payLen + gso - 1) / gso
	if numSeg == 0 {
		numSeg = 1
	}
	need := numSeg*headerLen + payLen
	if need > len(scratch) {
		return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch))
	}
	origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8])
	origFlags := pkt[csumStart+13]
	const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08)
	// Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP
	// header is 60 bytes; copy onto the stack, zero the per-segment-varying
	// fields, sum once.
	var tmp [60]byte
	copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
	tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq
	tmp[13] = 0                                 // flags
	tmp[16], tmp[17] = 0, 0                     // csum
	baseTcpHdrSum := checksumBytes(tmp[:tcpHdrLen], 0)
	// Pseudo-header src+dst+proto contribution (tcpLen varies per segment).
	var baseProtoSum uint32
	if isV4 {
		baseProtoSum = checksumBytes(pkt[12:16], 0)
		baseProtoSum = checksumBytes(pkt[16:20], baseProtoSum)
	} else {
		baseProtoSum = checksumBytes(pkt[8:24], 0)
		baseProtoSum = checksumBytes(pkt[24:40], baseProtoSum)
	}
	baseProtoSum += uint32(unix.IPPROTO_TCP)
	// Precompute IPv4 header sum with total_len/id/csum zeroed.
	var origIPID uint16
	var ihl int
	var baseIPHdrSum uint32
	if isV4 {
		origIPID = binary.BigEndian.Uint16(pkt[4:6])
		ihl = int(pkt[0]&0x0f) * 4
		if ihl < 20 || ihl > csumStart {
			return fmt.Errorf("bad IPv4 IHL: %d", ihl)
		}
		var ipTmp [60]byte
		copy(ipTmp[:ihl], pkt[:ihl])
		ipTmp[2], ipTmp[3] = 0, 0   // total_len
		ipTmp[4], ipTmp[5] = 0, 0   // id
		ipTmp[10], ipTmp[11] = 0, 0 // checksum
		baseIPHdrSum = checksumBytes(ipTmp[:ihl], 0)
	}
	off := 0
	for i := 0; i < numSeg; i++ {
		segStart := i * gso
		segEnd := segStart + gso
		if segEnd > payLen {
			segEnd = payLen
		}
		segPayLen := segEnd - segStart
		// Replicate the full header prefix, then this segment's payload.
		copy(scratch[off:], pkt[:headerLen])
		copy(scratch[off+headerLen:], payload[segStart:segEnd])
		seg := scratch[off : off+headerLen+segPayLen]
		off += headerLen + segPayLen
		segSeq := origSeq + uint32(segStart)
		segFlags := origFlags
		// FIN/PSH only belong on the final segment of the original packet.
		if i != numSeg-1 {
			segFlags = origFlags &^ tcpFinPsh
		}
		totalLen := headerLen + segPayLen
		// Patch IP header and write the v4 header checksum from the precomputed base.
		if isV4 {
			segID := origIPID + uint16(i)
			binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen))
			binary.BigEndian.PutUint16(seg[4:6], segID)
			ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
			binary.BigEndian.PutUint16(seg[10:12], checksumFold(ipSum))
		} else {
			// IPv6 payload length excludes the 40-byte fixed header but
			// includes any extension headers between [40:csumStart].
			binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen))
		}
		// Patch TCP header.
		binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq)
		seg[csumStart+13] = segFlags
		// (csum is written below; its prior contents in `seg` don't affect the
		// computation since we never sum over the segment's own header.)
		tcpLen := tcpHdrLen + segPayLen
		paySum := checksumBytes(payload[segStart:segEnd], 0)
		// Combine pre-folded uint32s into a wider accumulator, then fold. Using
		// uint64 guards against overflow when segSeq's high bits set.
		wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(uint32(wide)))
		*out = append(*out, seg)
	}
	return nil
}
// checksumBytes returns the Internet-checksum partial sum of b, seeded with
// initial. Result is a 32-bit accumulator; the caller folds to 16.
//
// The main loop consumes 32 bytes per iteration into two independent 64-bit
// accumulators so the CPU can overlap the adds; the tail handles remaining
// 4-byte words, an odd 16-bit chunk, and a final high-padded byte. The
// closing double-fold collapses carries that crossed the 32-bit lane.
func checksumBytes(b []byte, initial uint32) uint32 {
	even, odd := uint64(initial), uint64(0)
	for ; len(b) >= 32; b = b[32:] {
		even += uint64(binary.BigEndian.Uint32(b[0:4])) +
			uint64(binary.BigEndian.Uint32(b[8:12])) +
			uint64(binary.BigEndian.Uint32(b[16:20])) +
			uint64(binary.BigEndian.Uint32(b[24:28]))
		odd += uint64(binary.BigEndian.Uint32(b[4:8])) +
			uint64(binary.BigEndian.Uint32(b[12:16])) +
			uint64(binary.BigEndian.Uint32(b[20:24])) +
			uint64(binary.BigEndian.Uint32(b[28:32]))
	}
	acc := even + odd
	for ; len(b) >= 4; b = b[4:] {
		acc += uint64(binary.BigEndian.Uint32(b[:4]))
	}
	if len(b) >= 2 {
		acc += uint64(binary.BigEndian.Uint16(b[:2]))
		b = b[2:]
	}
	if len(b) == 1 {
		acc += uint64(b[0]) << 8
	}
	acc = (acc & 0xffffffff) + (acc >> 32)
	acc = (acc & 0xffffffff) + (acc >> 32)
	return uint32(acc)
}
// checksumFold reduces a 32-bit Internet-checksum accumulator to the final
// 16-bit checksum value: carries above bit 15 are folded back in (two rounds
// always suffice for a 32-bit input), then the result is complemented per
// RFC 1071.
func checksumFold(sum uint32) uint16 {
	sum = (sum & 0xffff) + (sum >> 16)
	sum += sum >> 16
	return ^uint16(sum)
}
// pseudoHeaderIPv4 returns the partial one's-complement sum of the IPv4
// TCP/UDP pseudo-header: source address, destination address, a zero byte
// paired with the protocol number, and the 16-bit transport-segment length.
func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint32 {
	acc := checksumBytes(dst, checksumBytes(src, 0))
	return acc + uint32(proto) + uint32(tcpLen)
}
// pseudoHeaderIPv6 returns the partial one's-complement sum of the IPv6
// pseudo-header: the two 16-byte addresses, the upper-layer packet length
// (carried as a 32-bit field, hence the split into two 16-bit halves), and
// the next-header value.
func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint32 {
	acc := checksumBytes(dst, checksumBytes(src, 0))
	acc += uint32(tcpLen>>16) + uint32(tcpLen&0xffff)
	return acc + uint32(proto)
}

View File

@@ -0,0 +1,333 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"encoding/binary"
"os"
"testing"
"golang.org/x/sys/unix"
)
// verifyChecksum reports whether the one's-complement sum over b, seeded with
// pseudo (a pseudo-header partial sum, or 0), folds to the all-ones value
// that marks a valid Internet checksum. Two folds always clear every carry
// out of a 32-bit accumulator.
func verifyChecksum(b []byte, pseudo uint32) bool {
	acc := checksumBytes(b, pseudo)
	acc = (acc & 0xffff) + (acc >> 16)
	acc = (acc & 0xffff) + (acc >> 16)
	return uint16(acc) == 0xffff
}
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
// `payLen` bytes split at `mss`, and returns it together with a matching
// virtio_net_hdr that asks for TCPv4 segmentation with checksum offload.
//
// Layout: 20-byte IPv4 header (no options), 20-byte TCP header (no options),
// then payLen bytes of a repeating 0..255 pattern so per-segment checksums
// differ from segment to segment.
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, VirtioNetHdr) {
	t.Helper()
	const ipLen = 20
	const tcpLen = 20
	pkt := make([]byte, ipLen+tcpLen+payLen)
	// IPv4 header
	pkt[0] = 0x45 // version 4, IHL 5 (20-byte header, no options)
	// total length is meaningless for TSO but set it anyway
	binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
	binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID; segments are expected to increment from here
	pkt[8] = 64 // TTL
	pkt[9] = unix.IPPROTO_TCP
	copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
	copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
	// TCP header
	binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
	binary.BigEndian.PutUint16(pkt[22:24], 80)    // dport
	binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
	binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
	pkt[32] = 0x50 // data offset 5 words
	pkt[33] = 0x18 // ACK | PSH
	binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
	// payload: repeating byte pattern
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i & 0xff)
	}
	return pkt, VirtioNetHdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen), // checksum region starts at the TCP header
		CsumOffset: 16,            // TCP checksum field offset within that region
	}
}
// TestSegmentTCPv4 splits a 3×MSS IPv4 TSO superpacket and verifies each
// resulting segment: total length, incrementing IP ID, advancing TCP
// sequence number, PSH cleared on all but the last segment, and valid IPv4
// header and TCP checksums.
func TestSegmentTCPv4(t *testing.T) {
	const mss = 100
	const numSeg = 3
	pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
	scratch := make([]byte, tunSegBufSize)
	var out [][]byte
	if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentTCP: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		// 40 = 20-byte IPv4 header + 20-byte TCP header (as built by buildTSOv4).
		if len(seg) != 40+mss {
			t.Errorf("seg %d: unexpected len %d", i, len(seg))
		}
		totalLen := binary.BigEndian.Uint16(seg[2:4])
		if totalLen != uint16(40+mss) {
			t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
		}
		// IP ID must increment per segment from the original 0x4242.
		id := binary.BigEndian.Uint16(seg[4:6])
		if id != 0x4242+uint16(i) {
			t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
		}
		// TCP sequence advances by one MSS worth of payload per segment.
		seq := binary.BigEndian.Uint32(seg[24:28])
		wantSeq := uint32(10000 + i*mss)
		if seq != wantSeq {
			t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
		}
		flags := seg[33]
		wantFlags := byte(0x10) // ACK only, PSH cleared
		if i == numSeg-1 {
			wantFlags = 0x18 // ACK | PSH preserved on last
		}
		if flags != wantFlags {
			t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
		}
		// IPv4 header checksum must verify against itself.
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		// TCP checksum must verify against the pseudo-header.
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
// TestSegmentTCPv4OddTail checks that a payload that is not an exact multiple
// of the MSS produces a short final segment (100, 100, 50 for 250 bytes at
// MSS 100) and that every segment still carries valid IP and TCP checksums.
func TestSegmentTCPv4OddTail(t *testing.T) {
	pkt, hdr := buildTSOv4(t, 250, 100)
	scratch := make([]byte, tunSegBufSize)
	var segs [][]byte
	if err := segmentTCP(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentTCP: %v", err)
	}
	if len(segs) != 3 {
		t.Fatalf("want 3 segments, got %d", len(segs))
	}
	for i, wantPay := range []int{100, 100, 50} {
		seg := segs[i]
		if got := len(seg) - 40; got != wantPay {
			t.Errorf("seg %d: pay len %d want %d", i, got, wantPay)
		}
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPay)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
// TestSegmentTCPv6 splits a 2×MSS IPv6 TSO superpacket and verifies segment
// lengths, the IPv6 payload_length field, advancing TCP sequence numbers,
// FIN and PSH cleared on all but the final segment, and TCP checksums
// computed over the IPv6 pseudo-header.
func TestSegmentTCPv6(t *testing.T) {
	const ipLen = 40
	const tcpLen = 20
	const mss = 120
	const numSeg = 2
	payLen := mss * numSeg
	pkt := make([]byte, ipLen+tcpLen+payLen)
	// IPv6 header
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen)) // payload_length
	pkt[6] = unix.IPPROTO_TCP // next header
	pkt[7] = 64               // hop limit
	// src/dst fe80::1 / fe80::2
	pkt[8] = 0xfe
	pkt[9] = 0x80
	pkt[23] = 1
	pkt[24] = 0xfe
	pkt[25] = 0x80
	pkt[39] = 2
	// TCP header
	binary.BigEndian.PutUint16(pkt[40:42], 12345) // sport
	binary.BigEndian.PutUint16(pkt[42:44], 80)    // dport
	binary.BigEndian.PutUint32(pkt[44:48], 7)     // seq
	binary.BigEndian.PutUint32(pkt[48:52], 99)    // ack
	pkt[52] = 0x50 // data offset 5 words
	pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
	binary.BigEndian.PutUint16(pkt[54:56], 65535) // window
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i)
	}
	hdr := VirtioNetHdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen),
		CsumOffset: 16,
	}
	scratch := make([]byte, tunSegBufSize)
	var out [][]byte
	if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentTCP: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		if len(seg) != ipLen+tcpLen+mss {
			t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
		}
		// IPv6 payload_length excludes the fixed 40-byte header.
		pl := binary.BigEndian.Uint16(seg[4:6])
		if pl != uint16(tcpLen+mss) {
			t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
		}
		seq := binary.BigEndian.Uint32(seg[44:48])
		if seq != uint32(7+i*mss) {
			t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
		}
		flags := seg[53]
		// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
		// cleared on all but the last; ACK(0x10) always preserved.
		wantFlags := byte(0x10)
		if i == numSeg-1 {
			wantFlags = 0x19
		}
		if flags != wantFlags {
			t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
		}
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
		if !verifyChecksum(seg[ipLen:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
// TestSegmentGSONonePassesThrough confirms that a packet marked
// VIRTIO_NET_HDR_GSO_NONE (and without NEEDS_CSUM) comes back as one
// untouched segment rather than being split or modified.
func TestSegmentGSONonePassesThrough(t *testing.T) {
	pkt, hdr := buildTSOv4(t, 100, 100)
	hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
	hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
	scratch := make([]byte, tunSegBufSize)
	segs := make([][]byte, 0, 1)
	if err := segmentInto(pkt, hdr, &segs, scratch); err != nil {
		t.Fatalf("segmentInto: %v", err)
	}
	if len(segs) != 1 {
		t.Fatalf("want 1 segment, got %d", len(segs))
	}
	if len(segs[0]) != len(pkt) {
		t.Fatalf("unexpected length: %d vs %d", len(segs[0]), len(pkt))
	}
}
// TestSegmentRejectsUDP verifies that segmentInto returns an error for the
// unsupported UDP GSO type instead of attempting to split the packet.
func TestSegmentRejectsUDP(t *testing.T) {
	var segs [][]byte
	hdr := VirtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
	err := segmentInto(nil, hdr, &segs, nil)
	if err == nil {
		t.Fatalf("expected rejection for UDP GSO")
	}
}
// BenchmarkSegmentTCPv4 measures segmentTCP throughput when splitting IPv4
// superpackets of several sizes at a standard 1460-byte MSS. SetBytes makes
// `go test -bench` report throughput in superpacket bytes per second.
func BenchmarkSegmentTCPv4(b *testing.B) {
	sizes := []struct {
		name   string
		payLen int
		mss    int
	}{
		{"64KiB_MSS1460", 65000, 1460},
		{"16KiB_MSS1460", 16384, 1460},
		{"4KiB_MSS1460", 4096, 1460},
	}
	for _, sz := range sizes {
		b.Run(sz.name, func(b *testing.B) {
			// Build one superpacket up front (same layout buildTSOv4 emits):
			// 20-byte IPv4 header, 20-byte TCP header, patterned payload.
			const ipLen = 20
			const tcpLen = 20
			pkt := make([]byte, ipLen+tcpLen+sz.payLen)
			pkt[0] = 0x45
			binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
			binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
			pkt[8] = 64
			pkt[9] = unix.IPPROTO_TCP
			copy(pkt[12:16], []byte{10, 0, 0, 1})
			copy(pkt[16:20], []byte{10, 0, 0, 2})
			binary.BigEndian.PutUint16(pkt[20:22], 12345)
			binary.BigEndian.PutUint16(pkt[22:24], 80)
			binary.BigEndian.PutUint32(pkt[24:28], 10000)
			binary.BigEndian.PutUint32(pkt[28:32], 20000)
			pkt[32] = 0x50
			pkt[33] = 0x18
			binary.BigEndian.PutUint16(pkt[34:36], 65535)
			for i := 0; i < sz.payLen; i++ {
				pkt[ipLen+tcpLen+i] = byte(i)
			}
			hdr := VirtioNetHdr{
				Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
				HdrLen:     uint16(ipLen + tcpLen),
				GSOSize:    uint16(sz.mss),
				CsumStart:  uint16(ipLen),
				CsumOffset: 16,
			}
			scratch := make([]byte, tunSegBufSize)
			// Reuse the output slice across iterations so only segmentTCP's
			// own steady-state allocations are measured.
			out := make([][]byte, 0, 64)
			b.SetBytes(int64(len(pkt)))
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				out = out[:0]
				if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
// allocation-free. We write to /dev/null so every call succeeds synchronously.
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
	// NOTE(review): os.O_WRONLY is passed to unix.Open; the values match on
	// Linux, but unix.O_WRONLY would be the consistent constant — confirm.
	fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
	if err != nil {
		t.Fatalf("open /dev/null: %v", err)
	}
	t.Cleanup(func() { _ = unix.Close(fd) })
	tf := &tunFile{fd: fd}
	// Point the first iovec at validVnetHdr (presumably a prebuilt header
	// blob — see its definition) so Write need not build one per call.
	tf.writeIovs[0].Base = &validVnetHdr[0]
	tf.writeIovs[0].SetLen(virtioNetHdrLen)
	payload := make([]byte, 1400)
	// Warm up (first call may trigger one-time internal allocations elsewhere).
	if _, err := tf.Write(payload); err != nil {
		t.Fatalf("Write: %v", err)
	}
	allocs := testing.AllocsPerRun(1000, func() {
		if _, err := tf.Write(payload); err != nil {
			t.Fatalf("Write: %v", err)
		}
	})
	if allocs != 0 {
		t.Fatalf("Write allocated %.1f times per call, want 0", allocs)
	}
}

View File

@@ -0,0 +1,38 @@
//go:build !e2e_testing
// +build !e2e_testing
package tio
import (
"testing"
"github.com/slackhq/nebula/overlay"
)
// runAdvMSSTests is the table for TestTunAdvMSS: each case pairs a tun
// device's default/max MTU configuration with a route, and the MSS value
// advMSS is expected to advertise (0 meaning no MSS clamp is emitted).
//
// NOTE(review): overlay.tun and its advMSS method are lowercase, i.e.
// unexported from package overlay; referencing them from package tio should
// not compile unless this file actually builds in package overlay or the
// type is exported — confirm.
var runAdvMSSTests = []struct {
	name     string
	tun      *overlay.tun
	r        overlay.Route
	expected int
}{
	// Standard case, default MTU is the device max MTU
	{"default", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{}, 0},
	{"default-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1440}, 0},
	{"default-low", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1200}, 1160},
	// Case where we have a route MTU set higher than the default
	{"route", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{}, 1400},
	{"route-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 1440}, 1400},
	{"route-high", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 8941}, 0},
}
// TestTunAdvMSS runs the advMSS table cases, checking the MSS value the tun
// device computes for each route against the expected value.
func TestTunAdvMSS(t *testing.T) {
	for _, tc := range runAdvMSSTests {
		t.Run(tc.name, func(t *testing.T) {
			if got := tc.tun.advMSS(tc.r); got != tc.expected {
				t.Errorf("got %d, want %d", got, tc.expected)
			}
		})
	}
}

View File

@@ -0,0 +1,39 @@
package tio
import "encoding/binary"
// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on
// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set).
const virtioNetHdrLen = 10
type VirtioNetHdr struct {
Flags uint8
GSOType uint8
HdrLen uint16
GSOSize uint16
CsumStart uint16
CsumOffset uint16
}
// decode reads a virtio_net_hdr in host byte order (TUN default; we never
// call TUNSETVNETLE so the kernel matches our endianness).
func (h *VirtioNetHdr) decode(b []byte) {
h.Flags = b[0]
h.GSOType = b[1]
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
}
// encode is the inverse of decode: writes the virtio_net_hdr fields into b
// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket
// on egress.
func (h *VirtioNetHdr) encode(b []byte) {
b[0] = h.Flags
b[1] = h.GSOType
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
}