mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
first try
This commit is contained in:
@@ -34,6 +34,16 @@ type tunFile struct {
|
|||||||
readPoll [2]unix.PollFd
|
readPoll [2]unix.PollFd
|
||||||
writePoll [2]unix.PollFd
|
writePoll [2]unix.PollFd
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
|
// vnetHdr is true when this fd was opened with IFF_VNET_HDR and the
|
||||||
|
// kernel successfully accepted TUNSETOFFLOAD. Reads include a leading
|
||||||
|
// virtio_net_hdr and may carry a TSO superpacket we must segment;
|
||||||
|
// writes must prepend a zeroed virtio_net_hdr.
|
||||||
|
vnetHdr bool
|
||||||
|
readBuf []byte // scratch for a single raw read (virtio hdr + superpacket)
|
||||||
|
segBuf []byte // backing store for segmented output
|
||||||
|
pending [][]byte // segments waiting to be drained by Read
|
||||||
|
pendingIdx int
|
||||||
}
|
}
|
||||||
|
|
||||||
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
|
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
|
||||||
@@ -41,9 +51,10 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
|||||||
if err := unix.SetNonblock(fd, true); err != nil {
|
if err := unix.SetNonblock(fd, true); err != nil {
|
||||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||||
}
|
}
|
||||||
return &tunFile{
|
out := &tunFile{
|
||||||
fd: fd,
|
fd: fd,
|
||||||
shutdownFd: r.shutdownFd,
|
shutdownFd: r.shutdownFd,
|
||||||
|
vnetHdr: r.vnetHdr,
|
||||||
readPoll: [2]unix.PollFd{
|
readPoll: [2]unix.PollFd{
|
||||||
{Fd: int32(fd), Events: unix.POLLIN},
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||||
@@ -52,10 +63,15 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
|||||||
{Fd: int32(fd), Events: unix.POLLOUT},
|
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||||
},
|
},
|
||||||
}, nil
|
}
|
||||||
|
if r.vnetHdr {
|
||||||
|
out.readBuf = make([]byte, tunReadBufSize)
|
||||||
|
out.segBuf = make([]byte, tunSegBufSize)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFd(fd int) (*tunFile, error) {
|
func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
|
||||||
if err := unix.SetNonblock(fd, true); err != nil {
|
if err := unix.SetNonblock(fd, true); err != nil {
|
||||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||||
}
|
}
|
||||||
@@ -69,6 +85,7 @@ func newTunFd(fd int) (*tunFile, error) {
|
|||||||
fd: fd,
|
fd: fd,
|
||||||
shutdownFd: shutdownFd,
|
shutdownFd: shutdownFd,
|
||||||
lastOne: true,
|
lastOne: true,
|
||||||
|
vnetHdr: vnetHdr,
|
||||||
readPoll: [2]unix.PollFd{
|
readPoll: [2]unix.PollFd{
|
||||||
{Fd: int32(fd), Events: unix.POLLIN},
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
@@ -78,6 +95,10 @@ func newTunFd(fd int) (*tunFile, error) {
|
|||||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
if vnetHdr {
|
||||||
|
out.readBuf = make([]byte, tunReadBufSize)
|
||||||
|
out.segBuf = make([]byte, tunSegBufSize)
|
||||||
|
}
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
@@ -134,7 +155,7 @@ func (r *tunFile) blockOnWrite() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tunFile) Read(buf []byte) (int, error) {
|
func (r *tunFile) readRaw(buf []byte) (int, error) {
|
||||||
for {
|
for {
|
||||||
if n, err := unix.Read(r.fd, buf); err == nil {
|
if n, err := unix.Read(r.fd, buf); err == nil {
|
||||||
return n, nil
|
return n, nil
|
||||||
@@ -149,7 +170,45 @@ func (r *tunFile) Read(buf []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) Read(buf []byte) (int, error) {
|
||||||
|
if !r.vnetHdr {
|
||||||
|
return r.readRaw(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if r.pendingIdx < len(r.pending) {
|
||||||
|
seg := r.pending[r.pendingIdx]
|
||||||
|
r.pendingIdx++
|
||||||
|
if len(seg) > len(buf) {
|
||||||
|
return 0, io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
return copy(buf, seg), nil
|
||||||
|
}
|
||||||
|
r.pending = r.pending[:0]
|
||||||
|
r.pendingIdx = 0
|
||||||
|
|
||||||
|
n, err := r.readRaw(r.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n < virtioNetHdrLen {
|
||||||
|
return 0, fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
|
||||||
|
}
|
||||||
|
var hdr virtioNetHdr
|
||||||
|
hdr.decode(r.readBuf[:virtioNetHdrLen])
|
||||||
|
if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf); err != nil {
|
||||||
|
// Drop and read again — a bad packet should not kill the reader.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// zeroVnetHdr is the prefix we prepend to every write when IFF_VNET_HDR is
|
||||||
|
// active and we have no offload info to convey.
|
||||||
|
var zeroVnetHdr [virtioNetHdrLen]byte
|
||||||
|
|
||||||
func (r *tunFile) Write(buf []byte) (int, error) {
|
func (r *tunFile) Write(buf []byte) (int, error) {
|
||||||
|
if !r.vnetHdr {
|
||||||
for {
|
for {
|
||||||
if n, err := unix.Write(r.fd, buf); err == nil {
|
if n, err := unix.Write(r.fd, buf); err == nil {
|
||||||
return n, nil
|
return n, nil
|
||||||
@@ -164,6 +223,28 @@ func (r *tunFile) Write(buf []byte) (int, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
iovs := [][]byte{zeroVnetHdr[:], buf}
|
||||||
|
for {
|
||||||
|
n, err := unix.Writev(r.fd, iovs)
|
||||||
|
if err == nil {
|
||||||
|
if n < virtioNetHdrLen {
|
||||||
|
return 0, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return n - virtioNetHdrLen, nil
|
||||||
|
}
|
||||||
|
if err == unix.EAGAIN {
|
||||||
|
if err = r.blockOnWrite(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tunFile) wakeForShutdown() error {
|
func (r *tunFile) wakeForShutdown() error {
|
||||||
@@ -233,7 +314,9 @@ type ifreqQLEN struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
// We don't know what flags the caller opened this fd with and can't turn
|
||||||
|
// on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds.
|
||||||
|
t, err := newTunGeneric(c, l, deviceFd, false, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -243,46 +326,83 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
// openTunDev opens /dev/net/tun, creating the device node first if it's
|
||||||
|
// missing (docker containers occasionally omit it).
|
||||||
|
func openTunDev() (int, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
return fd, nil
|
||||||
if os.IsNotExist(err) {
|
|
||||||
err = os.MkdirAll("/dev/net", 0755)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
|
||||||
}
|
}
|
||||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
if !os.IsNotExist(err) {
|
||||||
if err != nil {
|
return -1, err
|
||||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
}
|
||||||
|
if err = os.MkdirAll("/dev/net", 0755); err != nil {
|
||||||
|
return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||||
|
}
|
||||||
|
if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
|
||||||
|
return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return fd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen
|
||||||
|
// device name on success.
|
||||||
|
func tunSetIff(fd int, name string, flags uint16) (string, error) {
|
||||||
var req ifReq
|
var req ifReq
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
req.Flags = flags
|
||||||
|
copy(req.Name[:], name)
|
||||||
|
if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return strings.Trim(string(req.Name[:]), "\x00"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tsoOffloadFlags are the TUN_F_* bits we ask the kernel to enable when a
|
||||||
|
// TSO-capable TUN is available. CSUM is required as a prerequisite for TSO.
|
||||||
|
const tsoOffloadFlags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||||
|
|
||||||
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
|
baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||||
if multiqueue {
|
if multiqueue {
|
||||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
baseFlags |= unix.IFF_MULTI_QUEUE
|
||||||
}
|
}
|
||||||
nameStr := c.GetString("tun.dev", "")
|
nameStr := c.GetString("tun.dev", "")
|
||||||
copy(req.Name[:], nameStr)
|
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
_ = unix.Close(fd)
|
|
||||||
return nil, &NameError{
|
|
||||||
Name: nameStr,
|
|
||||||
Underlying: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, fd, vpnNetworks)
|
// First try to open with IFF_VNET_HDR + TUNSETOFFLOAD so we can receive
|
||||||
|
// TSO superpackets. If either step fails (older kernel, unprivileged
|
||||||
|
// container, etc.) we close and fall back to a plain TUN.
|
||||||
|
fd, err := openTunDev()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
vnetHdr := true
|
||||||
|
name, err := tunSetIff(fd, nameStr, baseFlags|unix.IFF_VNET_HDR)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
vnetHdr = false
|
||||||
|
} else if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
|
||||||
|
l.WithError(err).Warn("Failed to enable TUN offload (TSO); proceeding without virtio headers")
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
vnetHdr = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !vnetHdr {
|
||||||
|
fd, err = openTunDev()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
name, err = tunSetIff(fd, nameStr, baseFlags)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, &NameError{Name: nameStr, Underlying: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := newTunGeneric(c, l, fd, vnetHdr, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -293,8 +413,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
tfd, err := newTunFd(fd)
|
tfd, err := newTunFd(fd, vnetHdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -413,14 +533,22 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var req ifReq
|
flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
if t.vnetHdr {
|
||||||
copy(req.Name[:], t.Device)
|
flags |= unix.IFF_VNET_HDR
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
}
|
||||||
|
if _, err = tunSetIff(fd, t.Device, flags); err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t.vnetHdr {
|
||||||
|
if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
out, err := t.tunFile.newFriend(fd)
|
out, err := t.tunFile.newFriend(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
|
|||||||
234
overlay/tun_linux_offload.go
Normal file
234
overlay/tun_linux_offload.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
//go:build !android && !e2e_testing
|
||||||
|
// +build !android,!e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on
|
||||||
|
// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set).
|
||||||
|
const virtioNetHdrLen = 10
|
||||||
|
|
||||||
|
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
|
||||||
|
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
|
||||||
|
// prefix plus the virtio header.
|
||||||
|
const tunReadBufSize = 65535
|
||||||
|
|
||||||
|
// Space for segmented output. Worst case is many small segments, each paying
|
||||||
|
// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling.
|
||||||
|
const tunSegBufSize = 131072
|
||||||
|
|
||||||
|
type virtioNetHdr struct {
|
||||||
|
Flags uint8
|
||||||
|
GSOType uint8
|
||||||
|
HdrLen uint16
|
||||||
|
GSOSize uint16
|
||||||
|
CsumStart uint16
|
||||||
|
CsumOffset uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode reads a virtio_net_hdr in host byte order (TUN default; we never
|
||||||
|
// call TUNSETVNETLE so the kernel matches our endianness).
|
||||||
|
func (h *virtioNetHdr) decode(b []byte) {
|
||||||
|
h.Flags = b[0]
|
||||||
|
h.GSOType = b[1]
|
||||||
|
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
|
||||||
|
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
|
||||||
|
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
|
||||||
|
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
|
||||||
|
}
|
||||||
|
|
||||||
|
// segmentInto splits a TUN-side packet described by hdr into one or more
|
||||||
|
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
||||||
|
// sized to hold every segment (including replicated headers).
|
||||||
|
func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
|
||||||
|
switch hdr.GSOType {
|
||||||
|
case unix.VIRTIO_NET_HDR_GSO_NONE:
|
||||||
|
if len(pkt) > len(scratch) {
|
||||||
|
return fmt.Errorf("packet larger than segment buffer: %d > %d", len(pkt), len(scratch))
|
||||||
|
}
|
||||||
|
copy(scratch, pkt)
|
||||||
|
seg := scratch[:len(pkt)]
|
||||||
|
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||||
|
if err := finishChecksum(seg, hdr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*out = append(*out, seg)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
|
||||||
|
return segmentTCP(pkt, hdr, out, scratch)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported virtio gso type: %d", hdr.GSOType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// finishChecksum computes the L4 checksum for a non-GSO packet that the kernel
|
||||||
|
// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit
|
||||||
|
// checksum field; we zero it, fold a full sum (the field was pre-loaded with
|
||||||
|
// the pseudo-header partial sum by the kernel), and store the result.
|
||||||
|
func finishChecksum(seg []byte, hdr virtioNetHdr) error {
|
||||||
|
cs := int(hdr.CsumStart)
|
||||||
|
co := int(hdr.CsumOffset)
|
||||||
|
if cs+co+2 > len(seg) {
|
||||||
|
return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg))
|
||||||
|
}
|
||||||
|
// The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the
|
||||||
|
// L4 region starting at cs, folding the prior partial in as the seed.
|
||||||
|
partial := uint32(binary.BigEndian.Uint16(seg[cs+co : cs+co+2]))
|
||||||
|
seg[cs+co] = 0
|
||||||
|
seg[cs+co+1] = 0
|
||||||
|
sum := checksumBytes(seg[cs:], partial)
|
||||||
|
binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], checksumFold(sum))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// segmentTCP software-segments a TSO superpacket into one IP packet per MSS
|
||||||
|
// chunk. The caller guarantees hdr.GSOType is TCPV4 or TCPV6.
|
||||||
|
func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
|
||||||
|
if hdr.GSOSize == 0 {
|
||||||
|
return fmt.Errorf("gso_size is zero")
|
||||||
|
}
|
||||||
|
if int(hdr.HdrLen) > len(pkt) || hdr.HdrLen == 0 {
|
||||||
|
return fmt.Errorf("hdr_len %d out of range (pkt %d)", hdr.HdrLen, len(pkt))
|
||||||
|
}
|
||||||
|
if hdr.CsumStart == 0 || hdr.CsumStart >= hdr.HdrLen {
|
||||||
|
return fmt.Errorf("csum_start %d out of range (hdr_len %d)", hdr.CsumStart, hdr.HdrLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
isV4 := hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||||
|
headerLen := int(hdr.HdrLen)
|
||||||
|
csumStart := int(hdr.CsumStart)
|
||||||
|
|
||||||
|
if isV4 && csumStart < 20 {
|
||||||
|
return fmt.Errorf("csum_start %d too small for IPv4", csumStart)
|
||||||
|
}
|
||||||
|
if !isV4 && csumStart < 40 {
|
||||||
|
return fmt.Errorf("csum_start %d too small for IPv6", csumStart)
|
||||||
|
}
|
||||||
|
if headerLen-csumStart < 20 {
|
||||||
|
return fmt.Errorf("tcp header region too small: %d", headerLen-csumStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := pkt[headerLen:]
|
||||||
|
payLen := len(payload)
|
||||||
|
gso := int(hdr.GSOSize)
|
||||||
|
numSeg := (payLen + gso - 1) / gso
|
||||||
|
if numSeg == 0 {
|
||||||
|
numSeg = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
need := numSeg*headerLen + payLen
|
||||||
|
if need > len(scratch) {
|
||||||
|
return fmt.Errorf("scratch too small for %d segments: need %d have %d", numSeg, need, len(scratch))
|
||||||
|
}
|
||||||
|
|
||||||
|
origSeq := binary.BigEndian.Uint32(pkt[csumStart+4 : csumStart+8])
|
||||||
|
origFlags := pkt[csumStart+13]
|
||||||
|
const tcpFinPsh = 0x09 // FIN(0x01) | PSH(0x08)
|
||||||
|
|
||||||
|
var origIPID uint16
|
||||||
|
if isV4 {
|
||||||
|
origIPID = binary.BigEndian.Uint16(pkt[4:6])
|
||||||
|
}
|
||||||
|
|
||||||
|
off := 0
|
||||||
|
for i := 0; i < numSeg; i++ {
|
||||||
|
segStart := i * gso
|
||||||
|
segEnd := segStart + gso
|
||||||
|
if segEnd > payLen {
|
||||||
|
segEnd = payLen
|
||||||
|
}
|
||||||
|
segPayLen := segEnd - segStart
|
||||||
|
|
||||||
|
// Materialise IP+TCP header and this segment's payload chunk.
|
||||||
|
copy(scratch[off:], pkt[:headerLen])
|
||||||
|
copy(scratch[off+headerLen:], payload[segStart:segEnd])
|
||||||
|
seg := scratch[off : off+headerLen+segPayLen]
|
||||||
|
off += headerLen + segPayLen
|
||||||
|
|
||||||
|
// Fix IP header: total/payload length, v4 ID, v4 header csum.
|
||||||
|
if isV4 {
|
||||||
|
ihl := int(seg[0]&0x0f) * 4
|
||||||
|
if ihl < 20 || ihl > csumStart {
|
||||||
|
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(seg[2:4], uint16(headerLen+segPayLen))
|
||||||
|
binary.BigEndian.PutUint16(seg[4:6], origIPID+uint16(i))
|
||||||
|
seg[10] = 0
|
||||||
|
seg[11] = 0
|
||||||
|
binary.BigEndian.PutUint16(seg[10:12], checksumFold(checksumBytes(seg[:ihl], 0)))
|
||||||
|
} else {
|
||||||
|
// IPv6 payload length excludes the 40-byte fixed header but
|
||||||
|
// includes any extension headers that sit between [40:csumStart].
|
||||||
|
binary.BigEndian.PutUint16(seg[4:6], uint16(headerLen-40+segPayLen))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fix TCP header: seq, flags, checksum.
|
||||||
|
segSeq := origSeq + uint32(segStart)
|
||||||
|
binary.BigEndian.PutUint32(seg[csumStart+4:csumStart+8], segSeq)
|
||||||
|
if i != numSeg-1 {
|
||||||
|
seg[csumStart+13] = origFlags &^ tcpFinPsh
|
||||||
|
} else {
|
||||||
|
seg[csumStart+13] = origFlags
|
||||||
|
}
|
||||||
|
seg[csumStart+16] = 0
|
||||||
|
seg[csumStart+17] = 0
|
||||||
|
|
||||||
|
tcpLen := headerLen - csumStart + segPayLen
|
||||||
|
var psum uint32
|
||||||
|
if isV4 {
|
||||||
|
psum = pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, tcpLen)
|
||||||
|
} else {
|
||||||
|
psum = pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen)
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(seg[csumStart+16:csumStart+18], checksumFold(checksumBytes(seg[csumStart:csumStart+tcpLen], psum)))
|
||||||
|
|
||||||
|
*out = append(*out, seg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checksumBytes(b []byte, initial uint32) uint32 {
|
||||||
|
sum := initial
|
||||||
|
for len(b) >= 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||||
|
b = b[2:]
|
||||||
|
}
|
||||||
|
if len(b) == 1 {
|
||||||
|
sum += uint32(b[0]) << 8
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func checksumFold(sum uint32) uint16 {
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pseudoHeaderIPv4(src, dst []byte, proto byte, tcpLen int) uint32 {
|
||||||
|
sum := checksumBytes(src, 0)
|
||||||
|
sum = checksumBytes(dst, sum)
|
||||||
|
sum += uint32(proto)
|
||||||
|
sum += uint32(tcpLen)
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func pseudoHeaderIPv6(src, dst []byte, proto byte, tcpLen int) uint32 {
|
||||||
|
sum := checksumBytes(src, 0)
|
||||||
|
sum = checksumBytes(dst, sum)
|
||||||
|
sum += uint32(tcpLen >> 16)
|
||||||
|
sum += uint32(tcpLen & 0xffff)
|
||||||
|
sum += uint32(proto)
|
||||||
|
return sum
|
||||||
|
}
|
||||||
247
overlay/tun_linux_offload_test.go
Normal file
247
overlay/tun_linux_offload_test.go
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
//go:build !android && !e2e_testing
|
||||||
|
// +build !android,!e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// verifyChecksum confirms that the one's-complement sum across `b`, optionally
|
||||||
|
// seeded with a pseudo-header sum, folds to all-ones (valid).
|
||||||
|
func verifyChecksum(b []byte, pseudo uint32) bool {
|
||||||
|
sum := checksumBytes(b, pseudo)
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return uint16(sum) == 0xffff
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
|
||||||
|
// `payLen` bytes split at `mss`.
|
||||||
|
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) {
|
||||||
|
t.Helper()
|
||||||
|
const ipLen = 20
|
||||||
|
const tcpLen = 20
|
||||||
|
pkt := make([]byte, ipLen+tcpLen+payLen)
|
||||||
|
|
||||||
|
// IPv4 header
|
||||||
|
pkt[0] = 0x45 // version 4, IHL 5
|
||||||
|
// total length is meaningless for TSO but set it anyway
|
||||||
|
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
|
||||||
|
binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
|
||||||
|
pkt[8] = 64 // TTL
|
||||||
|
pkt[9] = unix.IPPROTO_TCP
|
||||||
|
copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
|
||||||
|
copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
|
||||||
|
|
||||||
|
// TCP header
|
||||||
|
binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
|
||||||
|
binary.BigEndian.PutUint16(pkt[22:24], 80) // dport
|
||||||
|
binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
|
||||||
|
binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
|
||||||
|
pkt[32] = 0x50 // data offset 5 words
|
||||||
|
pkt[33] = 0x18 // ACK | PSH
|
||||||
|
binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
|
||||||
|
|
||||||
|
// payload
|
||||||
|
for i := 0; i < payLen; i++ {
|
||||||
|
pkt[ipLen+tcpLen+i] = byte(i & 0xff)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkt, virtioNetHdr{
|
||||||
|
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||||
|
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
||||||
|
HdrLen: uint16(ipLen + tcpLen),
|
||||||
|
GSOSize: uint16(mss),
|
||||||
|
CsumStart: uint16(ipLen),
|
||||||
|
CsumOffset: 16,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSegmentTCPv4(t *testing.T) {
|
||||||
|
const mss = 100
|
||||||
|
const numSeg = 3
|
||||||
|
pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
|
||||||
|
|
||||||
|
scratch := make([]byte, tunSegBufSize)
|
||||||
|
var out [][]byte
|
||||||
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||||
|
t.Fatalf("segmentTCP: %v", err)
|
||||||
|
}
|
||||||
|
if len(out) != numSeg {
|
||||||
|
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, seg := range out {
|
||||||
|
if len(seg) != 40+mss {
|
||||||
|
t.Errorf("seg %d: unexpected len %d", i, len(seg))
|
||||||
|
}
|
||||||
|
totalLen := binary.BigEndian.Uint16(seg[2:4])
|
||||||
|
if totalLen != uint16(40+mss) {
|
||||||
|
t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
|
||||||
|
}
|
||||||
|
id := binary.BigEndian.Uint16(seg[4:6])
|
||||||
|
if id != 0x4242+uint16(i) {
|
||||||
|
t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
|
||||||
|
}
|
||||||
|
seq := binary.BigEndian.Uint32(seg[24:28])
|
||||||
|
wantSeq := uint32(10000 + i*mss)
|
||||||
|
if seq != wantSeq {
|
||||||
|
t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
|
||||||
|
}
|
||||||
|
flags := seg[33]
|
||||||
|
wantFlags := byte(0x10) // ACK only, PSH cleared
|
||||||
|
if i == numSeg-1 {
|
||||||
|
wantFlags = 0x18 // ACK | PSH preserved on last
|
||||||
|
}
|
||||||
|
if flags != wantFlags {
|
||||||
|
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
|
||||||
|
}
|
||||||
|
// IPv4 header checksum must verify against itself.
|
||||||
|
if !verifyChecksum(seg[:20], 0) {
|
||||||
|
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
||||||
|
}
|
||||||
|
// TCP checksum must verify against the pseudo-header.
|
||||||
|
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
|
||||||
|
if !verifyChecksum(seg[20:], psum) {
|
||||||
|
t.Errorf("seg %d: bad TCP checksum", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSegmentTCPv4OddTail(t *testing.T) {
|
||||||
|
// Payload of 250 bytes with MSS 100 → segments of 100, 100, 50.
|
||||||
|
pkt, hdr := buildTSOv4(t, 250, 100)
|
||||||
|
scratch := make([]byte, tunSegBufSize)
|
||||||
|
var out [][]byte
|
||||||
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||||
|
t.Fatalf("segmentTCP: %v", err)
|
||||||
|
}
|
||||||
|
if len(out) != 3 {
|
||||||
|
t.Fatalf("want 3 segments, got %d", len(out))
|
||||||
|
}
|
||||||
|
wantPayLens := []int{100, 100, 50}
|
||||||
|
for i, seg := range out {
|
||||||
|
if len(seg)-40 != wantPayLens[i] {
|
||||||
|
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i])
|
||||||
|
}
|
||||||
|
if !verifyChecksum(seg[:20], 0) {
|
||||||
|
t.Errorf("seg %d: bad IPv4 header checksum", i)
|
||||||
|
}
|
||||||
|
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i])
|
||||||
|
if !verifyChecksum(seg[20:], psum) {
|
||||||
|
t.Errorf("seg %d: bad TCP checksum", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSegmentTCPv6(t *testing.T) {
|
||||||
|
const ipLen = 40
|
||||||
|
const tcpLen = 20
|
||||||
|
const mss = 120
|
||||||
|
const numSeg = 2
|
||||||
|
payLen := mss * numSeg
|
||||||
|
pkt := make([]byte, ipLen+tcpLen+payLen)
|
||||||
|
|
||||||
|
// IPv6 header
|
||||||
|
pkt[0] = 0x60 // version 6
|
||||||
|
binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
|
||||||
|
pkt[6] = unix.IPPROTO_TCP
|
||||||
|
pkt[7] = 64
|
||||||
|
// src/dst fe80::1 / fe80::2
|
||||||
|
pkt[8] = 0xfe
|
||||||
|
pkt[9] = 0x80
|
||||||
|
pkt[23] = 1
|
||||||
|
pkt[24] = 0xfe
|
||||||
|
pkt[25] = 0x80
|
||||||
|
pkt[39] = 2
|
||||||
|
|
||||||
|
// TCP header
|
||||||
|
binary.BigEndian.PutUint16(pkt[40:42], 12345)
|
||||||
|
binary.BigEndian.PutUint16(pkt[42:44], 80)
|
||||||
|
binary.BigEndian.PutUint32(pkt[44:48], 7)
|
||||||
|
binary.BigEndian.PutUint32(pkt[48:52], 99)
|
||||||
|
pkt[52] = 0x50
|
||||||
|
pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
|
||||||
|
binary.BigEndian.PutUint16(pkt[54:56], 65535)
|
||||||
|
|
||||||
|
for i := 0; i < payLen; i++ {
|
||||||
|
pkt[ipLen+tcpLen+i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
hdr := virtioNetHdr{
|
||||||
|
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||||
|
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
|
||||||
|
HdrLen: uint16(ipLen + tcpLen),
|
||||||
|
GSOSize: uint16(mss),
|
||||||
|
CsumStart: uint16(ipLen),
|
||||||
|
CsumOffset: 16,
|
||||||
|
}
|
||||||
|
|
||||||
|
scratch := make([]byte, tunSegBufSize)
|
||||||
|
var out [][]byte
|
||||||
|
if err := segmentTCP(pkt, hdr, &out, scratch); err != nil {
|
||||||
|
t.Fatalf("segmentTCP: %v", err)
|
||||||
|
}
|
||||||
|
if len(out) != numSeg {
|
||||||
|
t.Fatalf("want %d segments, got %d", numSeg, len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, seg := range out {
|
||||||
|
if len(seg) != ipLen+tcpLen+mss {
|
||||||
|
t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
|
||||||
|
}
|
||||||
|
pl := binary.BigEndian.Uint16(seg[4:6])
|
||||||
|
if pl != uint16(tcpLen+mss) {
|
||||||
|
t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
|
||||||
|
}
|
||||||
|
seq := binary.BigEndian.Uint32(seg[44:48])
|
||||||
|
if seq != uint32(7+i*mss) {
|
||||||
|
t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
|
||||||
|
}
|
||||||
|
flags := seg[53]
|
||||||
|
// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
|
||||||
|
// cleared on all but the last; ACK(0x10) always preserved.
|
||||||
|
wantFlags := byte(0x10)
|
||||||
|
if i == numSeg-1 {
|
||||||
|
wantFlags = 0x19
|
||||||
|
}
|
||||||
|
if flags != wantFlags {
|
||||||
|
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
|
||||||
|
}
|
||||||
|
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
|
||||||
|
if !verifyChecksum(seg[ipLen:], psum) {
|
||||||
|
t.Errorf("seg %d: bad TCP checksum", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSegmentGSONonePassesThrough(t *testing.T) {
|
||||||
|
pkt, hdr := buildTSOv4(t, 100, 100)
|
||||||
|
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
|
||||||
|
hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
|
||||||
|
|
||||||
|
scratch := make([]byte, tunSegBufSize)
|
||||||
|
var out [][]byte
|
||||||
|
if err := segmentInto(pkt, hdr, &out, scratch); err != nil {
|
||||||
|
t.Fatalf("segmentInto: %v", err)
|
||||||
|
}
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("want 1 segment, got %d", len(out))
|
||||||
|
}
|
||||||
|
if len(out[0]) != len(pkt) {
|
||||||
|
t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSegmentRejectsUDP(t *testing.T) {
|
||||||
|
hdr := virtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
|
||||||
|
var out [][]byte
|
||||||
|
if err := segmentInto(nil, hdr, &out, nil); err == nil {
|
||||||
|
t.Fatalf("expected rejection for UDP GSO")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user