mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
GSO again
This commit is contained in:
@@ -7,12 +7,63 @@ import (
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
|
||||
// that don't do TSO segmentation. 65535 covers any single IP packet.
|
||||
const defaultBatchBufSize = 65535
|
||||
|
||||
// Queue is a readable/writable tun queue. One Queue is driven by a single
|
||||
// read goroutine plus concurrent writers (see Write / WriteReject below).
|
||||
type Queue interface {
|
||||
io.Closer
|
||||
|
||||
// Read returns one or more packets. The returned slices are borrowed
|
||||
// from the Queue's internal buffer and are only valid until the next
|
||||
// Read or Close on this Queue — callers must encrypt or copy each
|
||||
// slice before the next call. Not safe for concurrent Reads; exactly
|
||||
// one goroutine per Queue reads.
|
||||
Read() ([][]byte, error)
|
||||
|
||||
// Write emits a single packet on the plaintext (outside→inside)
|
||||
// delivery path. May run concurrently with WriteReject on the same
|
||||
// Queue, but not with itself.
|
||||
Write(p []byte) (int, error)
|
||||
|
||||
// WriteReject writes a single packet that originated from the inside
|
||||
// path (reject replies or self-forward) using scratch state distinct
|
||||
// from Write, so it can run concurrently with Write on the same Queue
|
||||
// without a data race. On backends without a shared-scratch Write, a
|
||||
// trivial delegation to Write is acceptable.
|
||||
WriteReject(p []byte) (int, error)
|
||||
}
|
||||
|
||||
type Device interface {
|
||||
io.ReadWriteCloser
|
||||
Queue
|
||||
Activate() error
|
||||
Networks() []netip.Prefix
|
||||
Name() string
|
||||
RoutesFor(netip.Addr) routing.Gateways
|
||||
SupportsMultiqueue() bool
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
NewMultiQueueReader() (Queue, error)
|
||||
}
|
||||
|
||||
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
|
||||
// assembled from a header prefix plus one or more borrowed payload
|
||||
// fragments, in a single vectored write (writev with a leading
|
||||
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
|
||||
// between the caller's decrypt buffer and the TUN. Backends without GSO
|
||||
// support return false from GSOSupported and coalescing is skipped.
|
||||
//
|
||||
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable — callers will
|
||||
// have filled in total length and pseudo-header partial). pays are
|
||||
// non-overlapping payload fragments whose concatenation is the full
|
||||
// superpacket payload; they are read-only from the writer's perspective
|
||||
// and must remain valid until the call returns. gsoSize is the MSS:
|
||||
// every segment except possibly the last is exactly that many bytes.
|
||||
// csumStart is the byte offset where the TCP header begins within hdr.
|
||||
//
|
||||
// hdr's TCP checksum field must already hold the pseudo-header partial
|
||||
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
|
||||
type GSOWriter interface {
|
||||
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
|
||||
GSOSupported() bool
|
||||
}
|
||||
|
||||
50
overlay/noop.go
Normal file
50
overlay/noop.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
type NoopTun struct{}
|
||||
|
||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||
return routing.Gateways{}
|
||||
}
|
||||
|
||||
func (NoopTun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (NoopTun) Networks() []netip.Prefix {
|
||||
return []netip.Prefix{}
|
||||
}
|
||||
|
||||
func (NoopTun) Name() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (NoopTun) Read() ([][]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (NoopTun) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (NoopTun) WriteReject(p []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (NoopTun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (NoopTun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, errors.New("unsupported")
|
||||
}
|
||||
|
||||
func (NoopTun) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -18,12 +18,39 @@ import (
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
rwc io.ReadWriteCloser
|
||||
fd int
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
if t.readBuf == nil {
|
||||
t.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := t.rwc.Read(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) Write(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) WriteReject(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
return t.rwc.Close()
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
@@ -32,10 +59,10 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: deviceFd,
|
||||
vpnNetworks: vpnNetworks,
|
||||
l: l,
|
||||
rwc: file,
|
||||
fd: deviceFd,
|
||||
vpnNetworks: vpnNetworks,
|
||||
l: l,
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
@@ -99,6 +126,6 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
rwc io.ReadWriteCloser
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
DefaultMTU int
|
||||
@@ -34,6 +34,9 @@ type tun struct {
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
@@ -124,11 +127,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
||||
Device: name,
|
||||
vpnNetworks: vpnNetworks,
|
||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
rwc: os.NewFile(uintptr(fd), ""),
|
||||
Device: name,
|
||||
vpnNetworks: vpnNetworks,
|
||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -158,8 +161,8 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
return t.ReadWriteCloser.Close()
|
||||
if t.rwc != nil {
|
||||
return t.rwc.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -503,15 +506,31 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.ReadWriteCloser.Read(buf)
|
||||
n, err := t.rwc.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
if t.readBuf == nil {
|
||||
t.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteReject(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
@@ -537,7 +556,7 @@ func (t *tun) Write(from []byte) (int, error) {
|
||||
|
||||
copy(buf[4:], from)
|
||||
|
||||
n, err := t.ReadWriteCloser.Write(buf)
|
||||
n, err := t.rwc.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
@@ -553,6 +572,6 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
@@ -20,6 +20,23 @@ type disabledTun struct {
|
||||
tx metrics.Counter
|
||||
rx metrics.Counter
|
||||
l *logrus.Logger
|
||||
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *disabledTun) Read() ([][]byte, error) {
|
||||
r, ok := <-t.read
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
t.tx.Inc(1)
|
||||
if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
||||
}
|
||||
|
||||
t.batchRet[0] = r
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
@@ -56,24 +73,6 @@ func (*disabledTun) Name() string {
|
||||
return "disabled"
|
||||
}
|
||||
|
||||
func (t *disabledTun) Read(b []byte) (int, error) {
|
||||
r, ok := <-t.read
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if len(r) > len(b) {
|
||||
return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
|
||||
}
|
||||
|
||||
t.tx.Inc(1)
|
||||
if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
||||
}
|
||||
|
||||
return copy(b, r), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
||||
out := make([]byte, len(b))
|
||||
out = iputil.CreateICMPEchoResponse(b, out)
|
||||
@@ -105,11 +104,15 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) WriteReject(b []byte) (int, error) {
|
||||
return t.Write(b)
|
||||
}
|
||||
|
||||
func (t *disabledTun) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *disabledTun) NewMultiQueueReader() (Queue, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -101,6 +100,9 @@ type tun struct {
|
||||
readPoll [2]unix.PollFd
|
||||
writePoll [2]unix.PollFd
|
||||
closed atomic.Bool
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
// blockOnRead waits until the tun fd is readable or shutdown has been signaled.
|
||||
@@ -155,7 +157,23 @@ func (t *tun) blockOnWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
if t.readBuf == nil {
|
||||
t.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteReject(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
var head [4]byte
|
||||
iovecs := [2]syscall.Iovec{
|
||||
@@ -563,7 +581,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
|
||||
|
||||
@@ -21,11 +21,38 @@ import (
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
rwc io.ReadWriteCloser
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
if t.readBuf == nil {
|
||||
t.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := t.rwc.Read(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) Write(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) WriteReject(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
return t.rwc.Close()
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
@@ -35,9 +62,9 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
l: l,
|
||||
vpnNetworks: vpnNetworks,
|
||||
rwc: &tunReadCloser{f: file},
|
||||
l: l,
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
@@ -155,6 +182,6 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||
}
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
@@ -34,16 +36,58 @@ type tunFile struct {
|
||||
readPoll [2]unix.PollFd
|
||||
writePoll [2]unix.PollFd
|
||||
closed bool
|
||||
|
||||
// vnetHdr is true when this fd was opened with IFF_VNET_HDR and the
|
||||
// kernel successfully accepted TUNSETOFFLOAD. Reads include a leading
|
||||
// virtio_net_hdr and may carry a TSO superpacket we must segment;
|
||||
// writes must prepend a zeroed virtio_net_hdr.
|
||||
vnetHdr bool
|
||||
readBuf []byte // scratch for a single raw read (virtio hdr + superpacket)
|
||||
segBuf []byte // backing store for segmented output
|
||||
segOff int // cursor into segBuf for the current Read drain
|
||||
pending [][]byte // segments returned from the most recent Read
|
||||
writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr
|
||||
// rejectIovs is a second preallocated iovec scratch used exclusively by
|
||||
// WriteReject (reject + self-forward from the inside path). It mirrors
|
||||
// writeIovs but lets listenIn goroutines emit reject packets without
|
||||
// racing with the listenOut coalescer that owns writeIovs.
|
||||
rejectIovs [2]unix.Iovec
|
||||
|
||||
// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
|
||||
// by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on
|
||||
// another queue never observes a half-written header.
|
||||
gsoHdrBuf [virtioNetHdrLen]byte
|
||||
// gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the
|
||||
// virtio header + IP/TCP header + up to gsoInitialPayIovs payload
|
||||
// fragments; grown on demand if a coalescer pushes more.
|
||||
gsoIovs []unix.Iovec
|
||||
}
|
||||
|
||||
// gsoInitialPayIovs is the starting capacity (in payload fragments) of
|
||||
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without
|
||||
// any reallocations.
|
||||
const gsoInitialPayIovs = 66
|
||||
|
||||
// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
|
||||
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
|
||||
// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
|
||||
// verification. All packets that reach the plain Write / WriteReject paths
|
||||
// already carry a valid L4 checksum (either supplied by a remote peer whose
|
||||
// ciphertext we AEAD-authenticated, or produced by finishChecksum during TSO
|
||||
// segmentation, or built locally by CreateRejectPacket), so trusting them is
|
||||
// safe.
|
||||
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
|
||||
|
||||
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
|
||||
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
||||
if err := unix.SetNonblock(fd, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||
}
|
||||
return &tunFile{
|
||||
out := &tunFile{
|
||||
fd: fd,
|
||||
shutdownFd: r.shutdownFd,
|
||||
vnetHdr: r.vnetHdr,
|
||||
readBuf: make([]byte, tunReadBufSize),
|
||||
readPoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLIN},
|
||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||
@@ -52,10 +96,21 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
||||
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if r.vnetHdr {
|
||||
out.segBuf = make([]byte, tunSegBufCap)
|
||||
out.writeIovs[0].Base = &validVnetHdr[0]
|
||||
out.writeIovs[0].SetLen(virtioNetHdrLen)
|
||||
out.rejectIovs[0].Base = &validVnetHdr[0]
|
||||
out.rejectIovs[0].SetLen(virtioNetHdrLen)
|
||||
out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs)
|
||||
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
|
||||
out.gsoIovs[0].SetLen(virtioNetHdrLen)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func newTunFd(fd int) (*tunFile, error) {
|
||||
func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
|
||||
if err := unix.SetNonblock(fd, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||
}
|
||||
@@ -69,6 +124,8 @@ func newTunFd(fd int) (*tunFile, error) {
|
||||
fd: fd,
|
||||
shutdownFd: shutdownFd,
|
||||
lastOne: true,
|
||||
vnetHdr: vnetHdr,
|
||||
readBuf: make([]byte, tunReadBufSize),
|
||||
readPoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLIN},
|
||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||
@@ -78,6 +135,16 @@ func newTunFd(fd int) (*tunFile, error) {
|
||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
}
|
||||
if vnetHdr {
|
||||
out.segBuf = make([]byte, tunSegBufCap)
|
||||
out.writeIovs[0].Base = &validVnetHdr[0]
|
||||
out.writeIovs[0].SetLen(virtioNetHdrLen)
|
||||
out.rejectIovs[0].Base = &validVnetHdr[0]
|
||||
out.rejectIovs[0].SetLen(virtioNetHdrLen)
|
||||
out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs)
|
||||
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
|
||||
out.gsoIovs[0].SetLen(virtioNetHdrLen)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
@@ -134,7 +201,7 @@ func (r *tunFile) blockOnWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tunFile) Read(buf []byte) (int, error) {
|
||||
func (r *tunFile) readRaw(buf []byte) (int, error) {
|
||||
for {
|
||||
if n, err := unix.Read(r.fd, buf); err == nil {
|
||||
return n, nil
|
||||
@@ -153,22 +220,238 @@ func (r *tunFile) Read(buf []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tunFile) Write(buf []byte) (int, error) {
|
||||
// Read reads one or more superpackets from the tun and returns the
|
||||
// resulting packets. The first read blocks via poll; once the fd is known
|
||||
// readable we drain additional packets non-blocking until the kernel queue
|
||||
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
|
||||
// segBuf headroom. This amortizes the poll wake over bursts of small
|
||||
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers
|
||||
// and are only valid until the next Read or Close on this Queue.
|
||||
func (r *tunFile) Read() ([][]byte, error) {
|
||||
r.pending = r.pending[:0]
|
||||
r.segOff = 0
|
||||
|
||||
// Initial (blocking) read. Retry on decode errors so a single bad
|
||||
// packet does not stall the reader.
|
||||
for {
|
||||
if n, err := unix.Write(r.fd, buf); err == nil {
|
||||
return n, nil
|
||||
} else if err == unix.EAGAIN {
|
||||
if err = r.blockOnWrite(); err != nil {
|
||||
n, err := r.readRaw(r.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !r.vnetHdr {
|
||||
r.pending = append(r.pending, r.readBuf[:n])
|
||||
// Non-vnetHdr mode shares one readBuf so we can't drain safely
|
||||
// without copying; return the single packet as before.
|
||||
return r.pending, nil
|
||||
}
|
||||
if err := r.decodeRead(n); err != nil {
|
||||
// Drop and read again — a bad packet should not kill the reader.
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Drain: non-blocking reads until the kernel queue is empty, the drain
|
||||
// cap is reached, or segBuf no longer has room for another worst-case
|
||||
// superpacket.
|
||||
for len(r.pending) < tunDrainCap && tunSegBufCap-r.segOff >= tunSegBufSize {
|
||||
n, err := unix.Read(r.fd, r.readBuf)
|
||||
if err != nil {
|
||||
// EAGAIN / EINTR / anything else: stop draining. We already
|
||||
// have a valid batch from the first read.
|
||||
break
|
||||
}
|
||||
if n <= 0 {
|
||||
break
|
||||
}
|
||||
if err := r.decodeRead(n); err != nil {
|
||||
// Drop this packet and stop the drain; we'd rather hand off
|
||||
// what we have than keep spinning here.
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return r.pending, nil
|
||||
}
|
||||
|
||||
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
|
||||
// the segments to r.pending, and advances r.segOff by the total scratch used.
|
||||
// Caller must have already ensured r.vnetHdr is true.
|
||||
func (r *tunFile) decodeRead(n int) error {
|
||||
if n < virtioNetHdrLen {
|
||||
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
|
||||
}
|
||||
var hdr virtioNetHdr
|
||||
hdr.decode(r.readBuf[:virtioNetHdrLen])
|
||||
before := len(r.pending)
|
||||
if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf[r.segOff:]); err != nil {
|
||||
return err
|
||||
}
|
||||
for k := before; k < len(r.pending); k++ {
|
||||
r.segOff += len(r.pending[k])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tunFile) Write(buf []byte) (int, error) {
|
||||
return r.writeWithScratch(buf, &r.writeIovs)
|
||||
}
|
||||
|
||||
// WriteReject emits a packet using a dedicated iovec scratch (rejectIovs)
|
||||
// distinct from the one used by the coalescer's Write path. This avoids a
|
||||
// data race between the inside (listenIn) goroutine emitting reject or
|
||||
// self-forward packets and the outside (listenOut) goroutine flushing TCP
|
||||
// coalescer passthroughs on the same tunFile.
|
||||
func (r *tunFile) WriteReject(buf []byte) (int, error) {
|
||||
return r.writeWithScratch(buf, &r.rejectIovs)
|
||||
}
|
||||
|
||||
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
|
||||
if !r.vnetHdr {
|
||||
for {
|
||||
if n, err := unix.Write(r.fd, buf); err == nil {
|
||||
return n, nil
|
||||
} else if err == unix.EAGAIN {
|
||||
if err = r.blockOnWrite(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
} else if err == unix.EINTR {
|
||||
continue
|
||||
} else if err == unix.EBADF {
|
||||
return 0, os.ErrClosed
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
|
||||
// to validVnetHdr during tunFile construction so we don't rebuild it here.
|
||||
iovs[1].Base = &buf[0]
|
||||
iovs[1].SetLen(len(buf))
|
||||
iovPtr := uintptr(unsafe.Pointer(&iovs[0]))
|
||||
// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev
|
||||
// either completes promptly or returns EAGAIN — it cannot park the
|
||||
// goroutine inside the kernel. That lets us use syscall.RawSyscall and
|
||||
// skip the runtime.entersyscall / exitsyscall bookkeeping on every
|
||||
// packet; we only pay that cost when we fall through to blockOnWrite.
|
||||
for {
|
||||
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
|
||||
if errno == 0 {
|
||||
runtime.KeepAlive(buf)
|
||||
if int(n) < virtioNetHdrLen {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return int(n) - virtioNetHdrLen, nil
|
||||
}
|
||||
if errno == unix.EAGAIN {
|
||||
runtime.KeepAlive(buf)
|
||||
if err := r.blockOnWrite(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
} else if err == unix.EINTR {
|
||||
continue
|
||||
} else if err == unix.EBADF {
|
||||
return 0, os.ErrClosed
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
runtime.KeepAlive(buf)
|
||||
return 0, errno
|
||||
}
|
||||
}
|
||||
|
||||
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
|
||||
// can accept WriteGSO. When false, callers should fall back to per-segment
|
||||
// Write calls.
|
||||
func (r *tunFile) GSOSupported() bool { return r.vnetHdr }
|
||||
|
||||
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
|
||||
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
|
||||
// and TCP pseudo-header partial set by the caller). pays are payload
|
||||
// fragments whose concatenation forms the full coalesced payload; each
|
||||
// slice is read-only and must stay valid until return. gsoSize is the MSS;
|
||||
// every segment except possibly the last is exactly gsoSize bytes.
|
||||
// csumStart is the byte offset where the TCP header begins within hdr.
|
||||
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
|
||||
if !r.vnetHdr {
|
||||
return fmt.Errorf("WriteGSO called on tun without IFF_VNET_HDR")
|
||||
}
|
||||
if len(hdr) == 0 || len(pays) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build the virtio_net_hdr. When pays total to <= gsoSize the kernel
|
||||
// would produce a single segment; keep NEEDS_CSUM semantics but skip
|
||||
// the GSO type so the kernel doesn't spuriously mark this as TSO.
|
||||
vhdr := virtioNetHdr{
|
||||
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||
HdrLen: uint16(len(hdr)),
|
||||
GSOSize: gsoSize,
|
||||
CsumStart: csumStart,
|
||||
CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header
|
||||
}
|
||||
var totalPay int
|
||||
for _, p := range pays {
|
||||
totalPay += len(p)
|
||||
}
|
||||
if totalPay > int(gsoSize) {
|
||||
if isV6 {
|
||||
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
||||
} else {
|
||||
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||
}
|
||||
} else {
|
||||
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
|
||||
vhdr.GSOSize = 0
|
||||
}
|
||||
vhdr.encode(r.gsoHdrBuf[:])
|
||||
|
||||
// Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is
|
||||
// wired to gsoHdrBuf at construction and never changes.
|
||||
need := 2 + len(pays)
|
||||
if cap(r.gsoIovs) < need {
|
||||
grown := make([]unix.Iovec, need)
|
||||
grown[0] = r.gsoIovs[0]
|
||||
r.gsoIovs = grown
|
||||
} else {
|
||||
r.gsoIovs = r.gsoIovs[:need]
|
||||
}
|
||||
r.gsoIovs[1].Base = &hdr[0]
|
||||
r.gsoIovs[1].SetLen(len(hdr))
|
||||
for i, p := range pays {
|
||||
r.gsoIovs[2+i].Base = &p[0]
|
||||
r.gsoIovs[2+i].SetLen(len(p))
|
||||
}
|
||||
|
||||
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
|
||||
iovCnt := uintptr(len(r.gsoIovs))
|
||||
for {
|
||||
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
|
||||
if errno == 0 {
|
||||
runtime.KeepAlive(hdr)
|
||||
runtime.KeepAlive(pays)
|
||||
if int(n) < virtioNetHdrLen {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if errno == unix.EAGAIN {
|
||||
runtime.KeepAlive(hdr)
|
||||
runtime.KeepAlive(pays)
|
||||
if err := r.blockOnWrite(); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
runtime.KeepAlive(hdr)
|
||||
runtime.KeepAlive(pays)
|
||||
return errno
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,7 +522,9 @@ type ifreqQLEN struct {
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
||||
// We don't know what flags the caller opened this fd with and can't turn
|
||||
// on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds.
|
||||
t, err := newTunGeneric(c, l, deviceFd, false, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -249,46 +534,83 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
// openTunDev opens /dev/net/tun, creating the device node first if it's
|
||||
// missing (docker containers occasionally omit it).
|
||||
func openTunDev() (int, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
if os.IsNotExist(err) {
|
||||
err = os.MkdirAll("/dev/net", 0755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||
}
|
||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||
}
|
||||
|
||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
if err == nil {
|
||||
return fd, nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return -1, err
|
||||
}
|
||||
if err = os.MkdirAll("/dev/net", 0755); err != nil {
|
||||
return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||
}
|
||||
if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
|
||||
return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||
}
|
||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen
|
||||
// device name on success.
|
||||
func tunSetIff(fd int, name string, flags uint16) (string, error) {
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
req.Flags = flags
|
||||
copy(req.Name[:], name)
|
||||
if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.Trim(string(req.Name[:]), "\x00"), nil
|
||||
}
|
||||
|
||||
// tsoOffloadFlags are the TUN_F_* bits we ask the kernel to enable when a
|
||||
// TSO-capable TUN is available. CSUM is required as a prerequisite for TSO.
|
||||
const tsoOffloadFlags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
baseFlags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
nameStr := c.GetString("tun.dev", "")
|
||||
copy(req.Name[:], nameStr)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
|
||||
// First try to open with IFF_VNET_HDR + TUNSETOFFLOAD so we can receive
|
||||
// TSO superpackets. If either step fails (older kernel, unprivileged
|
||||
// container, etc.) we close and fall back to a plain TUN.
|
||||
fd, err := openTunDev()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vnetHdr := true
|
||||
name, err := tunSetIff(fd, nameStr, baseFlags|unix.IFF_VNET_HDR|unix.IFF_NAPI)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, &NameError{
|
||||
Name: nameStr,
|
||||
Underlying: err,
|
||||
vnetHdr = false
|
||||
} else if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
|
||||
l.WithError(err).Warn("Failed to enable TUN offload (TSO); proceeding without virtio headers")
|
||||
_ = unix.Close(fd)
|
||||
vnetHdr = false
|
||||
}
|
||||
|
||||
if !vnetHdr {
|
||||
fd, err = openTunDev()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name, err = tunSetIff(fd, nameStr, baseFlags)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, &NameError{Name: nameStr, Underlying: err}
|
||||
}
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
t, err := newTunGeneric(c, l, fd, vpnNetworks)
|
||||
t, err := newTunGeneric(c, l, fd, vnetHdr, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -299,8 +621,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
|
||||
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
tfd, err := newTunFd(fd)
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
tfd, err := newTunFd(fd, vnetHdr)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
@@ -410,7 +732,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (Queue, error) {
|
||||
t.closeLock.Lock()
|
||||
defer t.closeLock.Unlock()
|
||||
|
||||
@@ -419,14 +741,22 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
copy(req.Name[:], t.Device)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
if t.vnetHdr {
|
||||
flags |= unix.IFF_VNET_HDR | unix.IFF_NAPI
|
||||
}
|
||||
if _, err = tunSetIff(fd, t.Device, flags); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.vnetHdr {
|
||||
if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
out, err := t.tunFile.newFriend(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
|
||||
331
overlay/tun_linux_offload.go
Normal file
331
overlay/tun_linux_offload.go
Normal file
@@ -0,0 +1,331 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on
|
||||
// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set).
|
||||
const virtioNetHdrLen = 10
|
||||
|
||||
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
|
||||
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
|
||||
// prefix plus the virtio header.
|
||||
const tunReadBufSize = 65535
|
||||
|
||||
// Space for segmented output. Worst case is many small segments, each paying
|
||||
// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling.
|
||||
const tunSegBufSize = 131072
|
||||
|
||||
// tunSegBufCap is the total size we allocate for the per-reader segment
|
||||
// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus
|
||||
// the same again as drain headroom so a Read wake can accumulate
|
||||
// additional packets after an initial big read without overflowing.
|
||||
const tunSegBufCap = tunSegBufSize * 2
|
||||
|
||||
// tunDrainCap caps how many packets a single Read will accumulate via
|
||||
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
|
||||
// bounding how much work a single caller holds before handing off.
|
||||
const tunDrainCap = 64
|
||||
|
||||
type virtioNetHdr struct {
|
||||
Flags uint8
|
||||
GSOType uint8
|
||||
HdrLen uint16
|
||||
GSOSize uint16
|
||||
CsumStart uint16
|
||||
CsumOffset uint16
|
||||
}
|
||||
|
||||
// decode reads a virtio_net_hdr in host byte order (TUN default; we never
|
||||
// call TUNSETVNETLE so the kernel matches our endianness).
|
||||
func (h *virtioNetHdr) decode(b []byte) {
|
||||
h.Flags = b[0]
|
||||
h.GSOType = b[1]
|
||||
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
|
||||
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
|
||||
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
|
||||
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
|
||||
}
|
||||
|
||||
// encode is the inverse of decode: writes the virtio_net_hdr fields into b
|
||||
// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket
|
||||
// on egress.
|
||||
func (h *virtioNetHdr) encode(b []byte) {
|
||||
b[0] = h.Flags
|
||||
b[1] = h.GSOType
|
||||
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
|
||||
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
|
||||
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
|
||||
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
|
||||
}
|
||||
|
||||
// segmentInto splits a TUN-side packet described by hdr into one or more
|
||||
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
||||
// sized to hold every segment (including replicated headers).
|
||||
func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
|
||||
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
|
||||
// carry coalescing info rather than checksum offsets. A TUN writing via
|
||||
// IFF_VNET_HDR should never emit this, but if it did we would silently
|
||||
// miscompute the segment checksums — refuse the packet instead.
|
||||
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
|
||||
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
|
||||
}
|
||||
|
||||
switch hdr.GSOType {
|
||||
case unix.VIRTIO_NET_HDR_GSO_NONE:
|
||||
if len(pkt) > len(scratch) {
|
||||
return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
|
||||
}
|
||||
copy(scratch, pkt)
|
||||
seg := scratch[:len(pkt)]
|
||||
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||
if err := finishChecksum(seg, hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*out = append(*out, seg)
|
||||
return nil
|
||||
|
||||
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
||||
return segmentTCP(pkt, hdr, out, scratch)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported virtio gso type: %d", hdr.GSOType)
|
||||
}
|
||||
}
|
||||
|
||||
// finishChecksum computes the L4 checksum for a non-GSO packet that the kernel
|
||||
// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit
|
||||
// checksum field; we zero it, fold a full sum (the field was pre-loaded with
|
||||
// the pseudo-header partial sum by the kernel), and store the result.
|
||||
func finishChecksum(seg []byte, hdr virtioNetHdr) error {
|
||||
cs := int(hdr.CsumStart)
|
||||
co := int(hdr.CsumOffset)
|
||||
if cs+co+2 > len(seg) {
|
||||
return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg))
|
||||
}
|
||||
// The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the
|
||||
// L4 region starting at cs, folding the prior partial in as the seed.
|
||||
partial := uint32(binary.BigEndian.Uint16(seg[cs+co : cs+co+2]))
|
||||
seg[cs+co] = 0
|
||||
seg[cs+co+1] = 0
|
||||
sum := checksumBytes(seg[cs:], partial)
|
||||
binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], checksumFold(sum))
|
||||
return nil
|
||||
}
|
||||
|
||||
// segmentTCP software-segments a TSO superpacket into one IP packet per MSS
|
||||
// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6.
|
||||
//
|
||||
// Hot-path shape: the per-segment loop only sums the payload chunk. The TCP
|
||||
// header, the IPv4 header, and the pseudo-header src/dst/proto contributions
|
||||
// are each summed once up front — every segment reuses those three pre-folded
|
||||
// uint32 values and combines them with small per-segment deltas (seq, flags,
|
||||
// tcpLen, ip_id, total_len) that are cheap to fold in.
|
||||
func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
|
||||
if hdr.GSOSize == 0 {
|
||||
return fmt.Errorf("gso_size is zero")
|
||||
}
|
||||
if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 {
|
||||
return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt))
|
||||
}
|
||||
if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen {
|
||||
return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen)
|
||||
}
|
||||
|
||||
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||
headerLen := int(hdr.HdrLen)
|
||||
csumStart := int(hdr.CsumStart)
|
||||
|
||||
if isV4 && csumStart < 20 {
|
||||
return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
|
||||
}
|
||||
if !isV4 && csumStart < 40 {
|
||||
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
|
||||
}
|
||||
tcpHdrLen := headerLen - csumStart
|
||||
if tcpHdrLen < 20 {
|
||||
return fmt.Errorf("tcp header region too small: %d", tcpHdrLen)
|
||||
}
|
||||
|
||||
payload := pkt[headerLen:]
|
||||
payLen := len(payload)
|
||||
gso := int(hdr.GSOSize)
|
||||
numSeg := (payLen + gso - 1) / gso
|
||||
if numSeg == 0 {
|
||||
numSeg = 1
|
||||
}
|
||||
|
||||
need := numSeg*headerLen + payLen
|
||||
if need > len(scratch) {
|
||||
return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch))
|
||||
}
|
||||
|
||||
origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8])
|
||||
origFlags := pkt[csumStart+13]
|
||||
const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08)
|
||||
|
||||
// Precompute the TCP header sum with seq/flags/csum zeroed. The max TCP
|
||||
// header is 60 bytes; copy onto the stack, zero the per-segment-varying
|
||||
// fields, sum once.
|
||||
var tmp [60]byte
|
||||
copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
|
||||
tmp[4], tmp[5], tmp[6], tmp[7] = 0, 0, 0, 0 // seq
|
||||
tmp[13] = 0 // flags
|
||||
tmp[16], tmp[17] = 0, 0 // csum
|
||||
baseTcpHdrSum := checksumBytes(tmp[:tcpHdrLen], 0)
|
||||
|
||||
// Pseudo-header src+dst+proto contribution (tcpLen varies per segment).
|
||||
var baseProtoSum uint32
|
||||
if isV4 {
|
||||
baseProtoSum = checksumBytes(pkt[12:16], 0)
|
||||
baseProtoSum = checksumBytes(pkt[16:20], baseProtoSum)
|
||||
} else {
|
||||
baseProtoSum = checksumBytes(pkt[8:24], 0)
|
||||
baseProtoSum = checksumBytes(pkt[24:40], baseProtoSum)
|
||||
}
|
||||
baseProtoSum += uint32(unix.IPPROTO_TCP)
|
||||
|
||||
// Precompute IPv4 header sum with total_len/id/csum zeroed.
|
||||
var origIPID uint16
|
||||
var ihl int
|
||||
var baseIPHdrSum uint32
|
||||
if isV4 {
|
||||
origIPID = binary.BigEndian.Uint16(pkt[4:6])
|
||||
ihl = int(pkt[0]&0x0f) * 4
|
||||
if ihl < 20 || ihl > csumStart {
|
||||
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
|
||||
}
|
||||
var ipTmp [60]byte
|
||||
copy(ipTmp[:ihl], pkt[:ihl])
|
||||
ipTmp[2], ipTmp[3] = 0, 0 // total_len
|
||||
ipTmp[4], ipTmp[5] = 0, 0 // id
|
||||
ipTmp[10], ipTmp[11] = 0, 0 // checksum
|
||||
baseIPHdrSum = checksumBytes(ipTmp[:ihl], 0)
|
||||
}
|
||||
|
||||
off := 0
|
||||
for i := 0; i < numSeg; i++ {
|
||||
segStart := i * gso
|
||||
segEnd := segStart + gso
|
||||
if segEnd > payLen {
|
||||
segEnd = payLen
|
||||
}
|
||||
segPayLen := segEnd - segStart
|
||||
|
||||
copy(scratch[off:], pkt[:headerLen])
|
||||
copy(scratch[off+headerLen:], payload[segStart:segEnd])
|
||||
seg := scratch[off : off+headerLen+segPayLen]
|
||||
off += headerLen + segPayLen
|
||||
|
||||
segSeq := origSeq + uint32(segStart)
|
||||
segFlags := origFlags
|
||||
if i != numSeg-1 {
|
||||
segFlags = origFlags &^ tcpFinPsh
|
||||
}
|
||||
totalLen := headerLen + segPayLen
|
||||
|
||||
// Patch IP header and write the v4 header checksum from the precomputed base.
|
||||
if isV4 {
|
||||
segID := origIPID + uint16(i)
|
||||
binary.BigEndian.PutUint16(seg[2:4], uint16(totalLen))
|
||||
binary.BigEndian.PutUint16(seg[4:6], segID)
|
||||
ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
|
||||
binary.BigEndian.PutUint16(seg[10:12], checksumFold(ipSum))
|
||||
} else {
|
||||
// IPv6 payload length excludes the 40-byte fixed header but
|
||||
// includes any extension headers between [40:csumStart].
|
||||
binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen))
|
||||
}
|
||||
|
||||
// Patch TCP header.
|
||||
binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq)
|
||||
seg[csumStart+13] = segFlags
|
||||
// (csum is written below; its prior contents in `seg` don't affect the
|
||||
// computation since we never sum over the segment's own header.)
|
||||
|
||||
tcpLen := tcpHdrLen + segPayLen
|
||||
paySum := checksumBytes(payload[segStart:segEnd], 0)
|
||||
|
||||
// Combine pre-folded uint32s into a wider accumulator, then fold. Using
|
||||
// uint64 guards against overflow when segSeq's high bits set.
|
||||
wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
|
||||
wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
|
||||
wide = (wide & 0xffffffff) + (wide >> 32)
|
||||
wide = (wide & 0xffffffff) + (wide >> 32)
|
||||
binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(uint32(wide)))
|
||||
|
||||
*out = append(*out, seg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checksumBytes returns the Internet-checksum partial sum of b, seeded with
|
||||
// initial. Result is a 32-bit accumulator; the caller folds to 16.
|
||||
//
|
||||
// Each 4-byte load is added directly into a 64-bit accumulator. Two parallel
|
||||
// accumulators break the serial dependency through `sum` and let the CPU
|
||||
// overlap independent adds. The final fold from 64 → 32 → 16 handles the
|
||||
// carries that accumulated across the 32-bit lane boundary.
|
||||
func checksumBytes(b []byte, initial uint32) uint32 {
|
||||
s0 := uint64(initial)
|
||||
var s1 uint64
|
||||
for len(b) >= 32 {
|
||||
s0 += uint64(binary.BigEndian.Uint32(b[0:4]))
|
||||
s1 += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
s0 += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
s1 += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
s0 += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
s1 += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
s0 += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
s1 += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
b = b[32:]
|
||||
}
|
||||
sum := s0 + s1
|
||||
for len(b) >= 4 {
|
||||
sum += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
sum += uint64(binary.BigEndian.Uint16(b[:2]))
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) == 1 {
|
||||
sum += uint64(b[0]) << 8
|
||||
}
|
||||
sum = (sum & 0xffffffff) + (sum >> 32)
|
||||
sum = (sum & 0xffffffff) + (sum >> 32)
|
||||
return uint32(sum)
|
||||
}
|
||||
|
||||
func checksumFold(sum uint32) uint16 {
|
||||
for sum>>16 != 0 {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint32 {
|
||||
sum := checksumBytes(src, 0)
|
||||
sum = checksumBytes(dst, sum)
|
||||
sum += uint32(proto)
|
||||
sum += uint32(tcpLen)
|
||||
return sum
|
||||
}
|
||||
|
||||
func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint32 {
|
||||
sum := checksumBytes(src, 0)
|
||||
sum = checksumBytes(dst, sum)
|
||||
sum += uint32(tcpLen >> 16)
|
||||
sum += uint32(tcpLen & 0xffff)
|
||||
sum += uint32(proto)
|
||||
return sum
|
||||
}
|
||||
333
overlay/tun_linux_offload_test.go
Normal file
333
overlay/tun_linux_offload_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// verifyChecksum confirms that the one's-complement sum across `b`, optionally
|
||||
// seeded with a pseudo-header sum, folds to all-ones (valid).
|
||||
func verifyChecksum(b []byte, pseudo uint32) bool {
|
||||
sum := checksumBytes(b, pseudo)
|
||||
for sum>>16 != 0 {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
return uint16(sum) == 0xffff
|
||||
}
|
||||
|
||||
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
|
||||
// `payLen` bytes split at `mss`.
|
||||
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) {
|
||||
t.Helper()
|
||||
const ipLen = 20
|
||||
const tcpLen = 20
|
||||
pkt := make([]byte, ipLen+tcpLen+payLen)
|
||||
|
||||
// IPv4 header
|
||||
pkt[0] = 0x45 // version 4, IHL 5
|
||||
// total length is meaningless for TSO but set it anyway
|
||||
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
|
||||
binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
|
||||
pkt[8] = 64 // TTL
|
||||
pkt[9] = unix.IPPROTO_TCP
|
||||
copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
|
||||
copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
|
||||
|
||||
// TCP header
|
||||
binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
|
||||
binary.BigEndian.PutUint16(pkt[22:24], 80) // dport
|
||||
binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
|
||||
binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
|
||||
pkt[32] = 0x50 // data offset 5 words
|
||||
pkt[33] = 0x18 // ACK | PSH
|
||||
binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
|
||||
|
||||
// payload
|
||||
for i := 0; i < payLen; i++ {
|
||||
pkt[ipLen+tcpLen+i] = byte(i & 0xff)
|
||||
}
|
||||
|
||||
return pkt, virtioNetHdr{
|
||||
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
||||
HdrLen: uint16(ipLen + tcpLen),
|
||||
GSOSize: uint16(mss),
|
||||
CsumStart: uint16(ipLen),
|
||||
CsumOffset: 16,
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentTCPv4(t *testing.T) {
|
||||
const mss = 100
|
||||
const numSeg = 3
|
||||
pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
|
||||
|
||||
scratch := make([]byte, tunSegBufSize)
|
||||
var out [][]byte
|
||||
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentTCP: %v", err)
|
||||
}
|
||||
if len(out) != numSeg {
|
||||
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
|
||||
}
|
||||
|
||||
for i, seg := range out {
|
||||
if len(seg) != 40+mss {
|
||||
t.Errorf("seg %d: unexpected len %d", i, len(seg))
|
||||
}
|
||||
totalLen := binary.BigEndian.Uint16(seg[2:4])
|
||||
if totalLen != uint16(40+mss) {
|
||||
t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
|
||||
}
|
||||
id := binary.BigEndian.Uint16(seg[4:6])
|
||||
if id != 0x4242+uint16(i) {
|
||||
t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
|
||||
}
|
||||
seq := binary.BigEndian.Uint32(seg[24:28])
|
||||
wantSeq := uint32(10000 + i*mss)
|
||||
if seq != wantSeq {
|
||||
t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
|
||||
}
|
||||
flags := seg[33]
|
||||
wantFlags := byte(0x10) // ACK only, PSH cleared
|
||||
if i == numSeg-1 {
|
||||
wantFlags = 0x18 // ACK | PSH preserved on last
|
||||
}
|
||||
if flags != wantFlags {
|
||||
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
|
||||
}
|
||||
// IPv4 header checksum must verify against itself.
|
||||
if !verifyChecksum(seg[:20], 0) {
|
||||
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
||||
}
|
||||
// TCP checksum must verify against the pseudo-header.
|
||||
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
|
||||
if !verifyChecksum(seg[20:], psum) {
|
||||
t.Errorf("seg %d: bad TCP checksum", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentTCPv4OddTail(t *testing.T) {
|
||||
// Payload of 250 bytes with MSS 100 → segments of 100, 100, 50.
|
||||
pkt, hdr := buildTSOv4(t, 250, 100)
|
||||
scratch := make([]byte, tunSegBufSize)
|
||||
var out [][]byte
|
||||
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentTCP: %v", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("want 3 segments, got %d", len(out))
|
||||
}
|
||||
wantPayLens := []int{100, 100, 50}
|
||||
for i, seg := range out {
|
||||
if len(seg)-40 != wantPayLens[i] {
|
||||
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i])
|
||||
}
|
||||
if !verifyChecksum(seg[:20], 0) {
|
||||
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
||||
}
|
||||
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i])
|
||||
if !verifyChecksum(seg[20:], psum) {
|
||||
t.Errorf("seg %d: bad TCP checksum", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentTCPv6(t *testing.T) {
|
||||
const ipLen = 40
|
||||
const tcpLen = 20
|
||||
const mss = 120
|
||||
const numSeg = 2
|
||||
payLen := mss * numSeg
|
||||
pkt := make([]byte, ipLen+tcpLen+payLen)
|
||||
|
||||
// IPv6 header
|
||||
pkt[0] = 0x60 // version 6
|
||||
binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
|
||||
pkt[6] = unix.IPPROTO_TCP
|
||||
pkt[7] = 64
|
||||
// src/dst fe80::1 / fe80::2
|
||||
pkt[8] = 0xfe
|
||||
pkt[9] = 0x80
|
||||
pkt[23] = 1
|
||||
pkt[24] = 0xfe
|
||||
pkt[25] = 0x80
|
||||
pkt[39] = 2
|
||||
|
||||
// TCP header
|
||||
binary.BigEndian.PutUint16(pkt[40:42], 12345)
|
||||
binary.BigEndian.PutUint16(pkt[42:44], 80)
|
||||
binary.BigEndian.PutUint32(pkt[44:48], 7)
|
||||
binary.BigEndian.PutUint32(pkt[48:52], 99)
|
||||
pkt[52] = 0x50
|
||||
pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
|
||||
binary.BigEndian.PutUint16(pkt[54:56], 65535)
|
||||
|
||||
for i := 0; i < payLen; i++ {
|
||||
pkt[ipLen+tcpLen+i] = byte(i)
|
||||
}
|
||||
|
||||
hdr := virtioNetHdr{
|
||||
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
|
||||
HdrLen: uint16(ipLen + tcpLen),
|
||||
GSOSize: uint16(mss),
|
||||
CsumStart: uint16(ipLen),
|
||||
CsumOffset: 16,
|
||||
}
|
||||
|
||||
scratch := make([]byte, tunSegBufSize)
|
||||
var out [][]byte
|
||||
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentTCP: %v", err)
|
||||
}
|
||||
if len(out) != numSeg {
|
||||
t.Fatalf("want %d segments, got %d", numSeg, len(out))
|
||||
}
|
||||
|
||||
for i, seg := range out {
|
||||
if len(seg) != ipLen+tcpLen+mss {
|
||||
t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
|
||||
}
|
||||
pl := binary.BigEndian.Uint16(seg[4:6])
|
||||
if pl != uint16(tcpLen+mss) {
|
||||
t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
|
||||
}
|
||||
seq := binary.BigEndian.Uint32(seg[44:48])
|
||||
if seq != uint32(7+i*mss) {
|
||||
t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
|
||||
}
|
||||
flags := seg[53]
|
||||
// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
|
||||
// cleared on all but the last; ACK(0x10) always preserved.
|
||||
wantFlags := byte(0x10)
|
||||
if i == numSeg-1 {
|
||||
wantFlags = 0x19
|
||||
}
|
||||
if flags != wantFlags {
|
||||
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
|
||||
}
|
||||
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
|
||||
if !verifyChecksum(seg[ipLen:], psum) {
|
||||
t.Errorf("seg %d: bad TCP checksum", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentGSONonePassesThrough(t *testing.T) {
|
||||
pkt, hdr := buildTSOv4(t, 100, 100)
|
||||
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
|
||||
hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
|
||||
|
||||
scratch := make([]byte, tunSegBufSize)
|
||||
var out [][]byte
|
||||
if err := segmentInto(pkt, hdr, &out, scratch); err != nil {
|
||||
t.Fatalf("segmentInto: %v", err)
|
||||
}
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("want 1 segment, got %d", len(out))
|
||||
}
|
||||
if len(out[0]) != len(pkt) {
|
||||
t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentRejectsUDP(t *testing.T) {
|
||||
hdr := virtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
|
||||
var out [][]byte
|
||||
if err := segmentInto(nil, hdr, &out, nil); err == nil {
|
||||
t.Fatalf("expected rejection for UDP GSO")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSegmentTCPv4(b *testing.B) {
|
||||
sizes := []struct {
|
||||
name string
|
||||
payLen int
|
||||
mss int
|
||||
}{
|
||||
{"64KiB_MSS1460", 65000, 1460},
|
||||
{"16KiB_MSS1460", 16384, 1460},
|
||||
{"4KiB_MSS1460", 4096, 1460},
|
||||
}
|
||||
for _, sz := range sizes {
|
||||
b.Run(sz.name, func(b *testing.B) {
|
||||
const ipLen = 20
|
||||
const tcpLen = 20
|
||||
pkt := make([]byte, ipLen+tcpLen+sz.payLen)
|
||||
pkt[0] = 0x45
|
||||
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
|
||||
binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
|
||||
pkt[8] = 64
|
||||
pkt[9] = unix.IPPROTO_TCP
|
||||
copy(pkt[12:16], []byte{10, 0, 0, 1})
|
||||
copy(pkt[16:20], []byte{10, 0, 0, 2})
|
||||
binary.BigEndian.PutUint16(pkt[20:22], 12345)
|
||||
binary.BigEndian.PutUint16(pkt[22:24], 80)
|
||||
binary.BigEndian.PutUint32(pkt[24:28], 10000)
|
||||
binary.BigEndian.PutUint32(pkt[28:32], 20000)
|
||||
pkt[32] = 0x50
|
||||
pkt[33] = 0x18
|
||||
binary.BigEndian.PutUint16(pkt[34:36], 65535)
|
||||
for i := 0; i < sz.payLen; i++ {
|
||||
pkt[ipLen+tcpLen+i] = byte(i)
|
||||
}
|
||||
hdr := virtioNetHdr{
|
||||
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
||||
HdrLen: uint16(ipLen + tcpLen),
|
||||
GSOSize: uint16(sz.mss),
|
||||
CsumStart: uint16(ipLen),
|
||||
CsumOffset: 16,
|
||||
}
|
||||
|
||||
scratch := make([]byte, tunSegBufSize)
|
||||
out := make([][]byte, 0, 64)
|
||||
|
||||
b.SetBytes(int64(len(pkt)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out = out[:0]
|
||||
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
|
||||
// allocation-free. We write to /dev/null so every call succeeds synchronously.
|
||||
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
|
||||
fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("open /dev/null: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = unix.Close(fd) })
|
||||
|
||||
tf := &tunFile{fd: fd, vnetHdr: true}
|
||||
tf.writeIovs[0].Base = &validVnetHdr[0]
|
||||
tf.writeIovs[0].SetLen(virtioNetHdrLen)
|
||||
|
||||
payload := make([]byte, 1400)
|
||||
// Warm up (first call may trigger one-time internal allocations elsewhere).
|
||||
if _, err := tf.Write(payload); err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
if _, err := tf.Write(payload); err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
})
|
||||
if allocs != 0 {
|
||||
t.Fatalf("Write allocated %.1f times per call, want 0", allocs)
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ package overlay
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -66,6 +65,25 @@ type tun struct {
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
if t.readBuf == nil {
|
||||
t.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteReject(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
@@ -141,7 +159,7 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
||||
@@ -394,7 +412,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ package overlay
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -59,6 +58,25 @@ type tun struct {
|
||||
fd int
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
if t.readBuf == nil {
|
||||
t.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteReject(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
@@ -124,7 +142,7 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.f.Read(buf)
|
||||
@@ -314,7 +332,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,17 @@ type TestTun struct {
|
||||
closed atomic.Bool
|
||||
rxPackets chan []byte // Packets to receive into nebula
|
||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *TestTun) Read() ([][]byte, error) {
|
||||
p, ok := <-t.rxPackets
|
||||
if !ok {
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
t.batchRet[0] = p
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
@@ -115,6 +126,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *TestTun) WriteReject(b []byte) (int, error) {
|
||||
return t.Write(b)
|
||||
}
|
||||
|
||||
func (t *TestTun) Close() error {
|
||||
if t.closed.CompareAndSwap(false, true) {
|
||||
close(t.rxPackets)
|
||||
@@ -123,19 +138,10 @@ func (t *TestTun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TestTun) Read(b []byte) (int, error) {
|
||||
p, ok := <-t.rxPackets
|
||||
if !ok {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
copy(b, p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (t *TestTun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *TestTun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ package overlay
|
||||
import (
|
||||
"crypto"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -36,6 +35,25 @@ type winTun struct {
|
||||
l *logrus.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *winTun) Read() ([][]byte, error) {
|
||||
if t.readBuf == nil {
|
||||
t.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := t.tun.Read(t.readBuf, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *winTun) WriteReject(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
@@ -229,10 +247,6 @@ func (t *winTun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *winTun) Read(b []byte) (int, error) {
|
||||
return t.tun.Read(b, 0)
|
||||
}
|
||||
|
||||
func (t *winTun) Write(b []byte) (int, error) {
|
||||
return t.tun.Write(b, 0)
|
||||
}
|
||||
@@ -241,7 +255,7 @@ func (t *winTun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *winTun) NewMultiQueueReader() (Queue, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,21 @@ type UserDevice struct {
|
||||
|
||||
inboundReader *io.PipeReader
|
||||
inboundWriter *io.PipeWriter
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (d *UserDevice) Read() ([][]byte, error) {
|
||||
if d.readBuf == nil {
|
||||
d.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := d.outboundReader.Read(d.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.batchRet[0] = d.readBuf[:n]
|
||||
return d.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (d *UserDevice) Activate() error {
|
||||
@@ -50,7 +65,7 @@ func (d *UserDevice) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (d *UserDevice) NewMultiQueueReader() (Queue, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -58,12 +73,12 @@ func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
||||
return d.inboundReader, d.outboundWriter
|
||||
}
|
||||
|
||||
func (d *UserDevice) Read(p []byte) (n int, err error) {
|
||||
return d.outboundReader.Read(p)
|
||||
}
|
||||
func (d *UserDevice) Write(p []byte) (n int, err error) {
|
||||
return d.inboundWriter.Write(p)
|
||||
}
|
||||
func (d *UserDevice) WriteReject(p []byte) (n int, err error) {
|
||||
return d.Write(p)
|
||||
}
|
||||
func (d *UserDevice) Close() error {
|
||||
d.inboundWriter.Close()
|
||||
d.outboundWriter.Close()
|
||||
|
||||
Reference in New Issue
Block a user