use wg tun library; batching & locking improvements

This commit is contained in:
Jay Wren
2025-11-04 15:04:24 -05:00
parent 42bee7cf17
commit 5cc3ff594a
33 changed files with 1353 additions and 121 deletions

View File

@@ -9,7 +9,6 @@ import (
"net"
"net/netip"
"os"
"strings"
"sync"
"sync/atomic"
"time"
@@ -22,10 +21,12 @@ import (
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct {
io.ReadWriteCloser
wgDevice wgtun.Device
fd int
Device string
vpnNetworks []netip.Prefix
@@ -71,63 +72,154 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
// This allows multiqueue readers to use the same wireguard Device batching as the main device
type wgDeviceWrapper struct {
dev wgtun.Device
buf []byte // Reusable buffer for single packet reads
}
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
// Use wireguard Device's batch API for single packet
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := w.dev.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
// Pass all buffers to WireGuard's batch write
return w.dev.Write(bufs, offset)
}
func (w *wgDeviceWrapper) Close() error {
return w.dev.Close()
}
// BatchRead implements batching for multiqueue readers
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
// The zero here is offset.
return w.dev.Read(bufs, sizes, 0)
}
// BatchSize returns the optimal batch size
func (w *wgDeviceWrapper) BatchSize() int {
return w.dev.BatchSize()
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
if err != nil {
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
}
file := wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.Device = "tun0"
t.wgDevice = wgDev
t.Device = name
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
// Check if /dev/net/tun exists, create if needed (for docker containers)
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
if err := os.MkdirAll("/dev/net", 0755); err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
}
devName := c.GetString("tun.dev", "")
mtu := c.GetInt("tun.mtu", DefaultMTU)
// Create TUN device manually to support multiqueue
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
if os.IsNotExist(err) {
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
if err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err
}
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
nameStr := c.GetString("tun.dev", "")
copy(req.Name[:], nameStr)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, &NameError{
Name: nameStr,
Underlying: err,
}
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], devName)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Enable TCP and UDP offload (TSO/GRO) for performance
// This allows the kernel to handle segmentation/coalescing
const (
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
offloads := tunTCPOffloads | tunUDPOffloads
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
// Log warning but don't fail - offload is optional
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
// Create wireguard device from file descriptor
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
}
name, err := wgDev.Name()
if err != nil {
_ = wgDev.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
}
// file is now owned by wgDev, get a new reference
file = wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.wgDevice = wgDev
t.Device = name
return t, nil
@@ -238,22 +330,44 @@ func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
// MUST match the flags used in newTun - includes IFF_VNET_HDR
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking mode - CRITICAL for proper netpoller integration
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Get MTU from main device
mtu := t.MaxMTU
if mtu == 0 {
mtu = DefaultMTU
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil
// Create wireguard Device from the file descriptor (just like the main device)
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
}
// Return a wrapper that uses the wireguard Device for all I/O
return &wgDeviceWrapper{dev: wgDev}, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -261,7 +375,68 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r
}
func (t *tun) Read(b []byte) (int, error) {
if t.wgDevice != nil {
// Use wireguard device which handles virtio headers internally
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := t.wgDevice.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
// Fallback: direct read from file (shouldn't happen in normal operation)
return t.ReadWriteCloser.Read(b)
}
// BatchRead reads multiple packets at once for improved performance
// bufs: slice of buffers to read into
// sizes: slice that will be filled with packet sizes
// Returns number of packets read
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Read(bufs, sizes, 0)
}
// Fallback: single packet read
n, err := t.ReadWriteCloser.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// BatchSize returns the optimal number of packets to read/write in a batch
func (t *tun) BatchSize() int {
if t.wgDevice != nil {
return t.wgDevice.BatchSize()
}
return 1
}
func (t *tun) Write(b []byte) (int, error) {
if t.wgDevice != nil {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
// Fallback: direct write (shouldn't happen in normal operation)
var nn int
maximum := len(b)
@@ -284,6 +459,22 @@ func (t *tun) Write(b []byte) (int, error) {
}
}
// WriteBatch writes multiple packets to the TUN device in a single syscall
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Write(bufs, offset)
}
// Fallback: write individually (shouldn't happen in normal operation)
for i, buf := range bufs {
_, err := t.Write(buf)
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
@@ -711,6 +902,10 @@ func (t *tun) Close() error {
close(t.routeChan)
}
if t.wgDevice != nil {
_ = t.wgDevice.Close()
}
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}