This commit is contained in:
Jay Wren
2025-11-11 13:15:30 -05:00
parent 3344a840d1
commit b68e504865
2 changed files with 206 additions and 4 deletions

View File

@@ -276,9 +276,26 @@ func (f *Interface) listenOut(i int) {
})
}
// BatchReader is an interface for devices that support reading multiple packets at once
type BatchReader interface {
	// BatchRead fills bufs with up to len(bufs) packets, records each
	// packet's length in the corresponding sizes slot, and returns the
	// number of packets read.
	BatchRead(bufs [][]byte, sizes []int) (int, error)
	// BatchSize reports how many packets the device can deliver per
	// BatchRead call (callers size bufs/sizes from this).
	BatchSize() int
}
// listenIn pins this goroutine to an OS thread and runs the inbound
// read loop for reader, choosing the batched path when the device
// supports it and the single-packet path otherwise.
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
	runtime.LockOSThread()

	// Prefer batching whenever the reader implements BatchReader.
	if br, ok := reader.(BatchReader); ok {
		f.listenInBatch(reader, br, i)
		return
	}
	f.listenInSingle(reader, i)
}
// listenInSingle is the non-batched inbound read loop: one packet per
// read, reusing mtu-sized scratch buffers across iterations.
func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) {
	packet := make([]byte, mtu)
	out := make([]byte, mtu)
	fwPacket := &firewall.Packet{}
@@ -302,6 +319,42 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
}
// listenInBatch is the batched inbound read loop. It reads up to
// batchReader.BatchSize() packets per BatchRead call and dispatches
// each one through consumeInsidePacket. Per-packet scratch state
// (out, fwPacket, nb, conntrack cache) is reused across batches to
// avoid per-packet allocation.
func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) {
	batchSize := batchReader.BatchSize()

	// Allocate buffers for batch reading, one mtu-sized slot per packet.
	// NOTE(review): if the device can deliver offload-coalesced packets
	// larger than mtu (TSO/GRO enabled), confirm mtu is large enough
	// for a full slot — oversize packets would not fit here.
	bufs := make([][]byte, batchSize)
	for idx := range bufs {
		bufs[idx] = make([]byte, mtu)
	}
	sizes := make([]int, batchSize)

	// Per-packet state (reused across batches)
	out := make([]byte, mtu)
	fwPacket := &firewall.Packet{}
	nb := make([]byte, 12)
	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)

	for {
		n, err := batchReader.BatchRead(bufs, sizes)
		if err != nil {
			// A closed fd during shutdown is the clean-exit path.
			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
				return
			}

			f.l.WithError(err).Error("Error while batch reading outbound packets")
			// This only seems to happen when something fatal happens to the fd, so exit.
			os.Exit(2)
		}

		// Process each packet in the batch
		for j := 0; j < n; j++ {
			f.consumeInsidePacket(bufs[j][:sizes[j]], fwPacket, nb, out, i, conntrackCache.Get(f.l))
		}
	}
}
// RegisterConfigChangeCallbacks wires the interface's reload handlers
// into the config watcher so firewall and send/recv-error settings are
// re-applied when the configuration is reloaded.
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
	c.RegisterReloadCallback(f.reloadFirewall)
	c.RegisterReloadCallback(f.reloadSendRecvError)

View File

@@ -66,6 +66,78 @@ type ifreqQLEN struct {
pad [8]byte
}
const (
	// virtioNetHdrLen is the number of bytes the kernel prepends to /
	// expects in front of each packet when IFF_VNET_HDR is enabled.
	// NOTE(review): 10 assumes the default vnet header size (no
	// TUNSETVNETHDRSZ override) — confirm against the tun setup code.
	virtioNetHdrLen = 10 // Size of virtio_net_hdr structure
)
// tunVirtioReader wraps a file descriptor that has IFF_VNET_HDR enabled
// and strips the virtio header on reads, adds it on writes
type tunVirtioReader struct {
	f *os.File
	// buf is scratch space shared by Read and Write (header plus a
	// maximum-size packet). NOTE(review): because both methods use the
	// same buffer, a concurrent Read and Write on one tunVirtioReader
	// would race — verify each instance is used from one goroutine.
	buf [virtioNetHdrLen + 65535]byte // Space for header + max packet
}
// Read reads one packet from the tun fd, strips the leading virtio
// header, and copies the payload into b. It returns the number of
// payload bytes actually copied into b.
func (r *tunVirtioReader) Read(b []byte) (int, error) {
	// Read into our buffer which has space for the virtio header
	n, err := r.f.Read(r.buf[:])
	if err != nil {
		return 0, err
	}
	// Every read must carry at least the virtio header (first 10 bytes).
	if n < virtioNetHdrLen {
		return 0, fmt.Errorf("packet too short: %d bytes", n)
	}
	// Copy payload (after header) to destination. Return the copied
	// length, not n - virtioNetHdrLen: if b is smaller than the payload,
	// copy truncates, and returning the larger count would violate the
	// io.Reader contract (count must be <= len(b)).
	copied := copy(b, r.buf[virtioNetHdrLen:n])
	return copied, nil
}
// Write prepends a zeroed virtio header to b and writes the framed
// packet to the tun fd. It returns the number of payload bytes written.
func (r *tunVirtioReader) Write(b []byte) (int, error) {
	// Guard against oversized packets: slicing r.buf past its array
	// length below would panic, and silently truncating would corrupt
	// the packet.
	if len(b) > len(r.buf)-virtioNetHdrLen {
		return 0, fmt.Errorf("packet too long: %d bytes", len(b))
	}
	// Zero out the virtio header (no offload from userspace write)
	for i := 0; i < virtioNetHdrLen; i++ {
		r.buf[i] = 0
	}
	// Copy packet data after header
	copy(r.buf[virtioNetHdrLen:], b)
	// Write with header prepended
	n, err := r.f.Write(r.buf[:virtioNetHdrLen+len(b)])
	if err != nil {
		return 0, err
	}
	// A write shorter than the header would make n-virtioNetHdrLen
	// negative; surface it as a short write instead.
	if n < virtioNetHdrLen {
		return 0, io.ErrShortWrite
	}
	// Return payload size (excluding header)
	return n - virtioNetHdrLen, nil
}
// Close closes the underlying tun file descriptor.
func (r *tunVirtioReader) Close() error {
	return r.f.Close()
}
// BatchRead satisfies the BatchReader interface for multiqueue file
// descriptors. No true batching is available here (there is no
// wireguard Device behind this fd), so it performs exactly one Read
// into bufs[0] and reports a single packet.
// TODO: Could implement proper batching with unix.Recvmmsg
func (r *tunVirtioReader) BatchRead(bufs [][]byte, sizes []int) (int, error) {
	size, err := r.Read(bufs[0])
	if err != nil {
		return 0, err
	}

	sizes[0] = size
	return 1, nil
}
// BatchSize returns the batch size for multiqueue readers.
// Always 1: see BatchRead, which falls back to single-packet reads.
func (r *tunVirtioReader) BatchSize() int {
	// Multiqueue readers use single packet mode for now
	return 1
}
// newTunFromFd builds a tun from an already-open device fd, wrapping it
// in a wireguard unmonitored TUN device (which also reports the name).
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
	wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
	if err != nil {
@@ -106,7 +178,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
@@ -122,6 +194,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return nil, err
}
// Enable TCP and UDP offload (TSO/GRO) for performance
// This allows the kernel to handle segmentation/coalescing
const (
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
offloads := tunTCPOffloads | tunUDPOffloads
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
// Log warning but don't fail - offload is optional
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
// Create wireguard device from file descriptor
@@ -251,16 +335,18 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
}
var req ifReq
// MUST match the flags used in newTun
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
// MUST match the flags used in newTun - includes IFF_VNET_HDR
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil
// Wrap in virtio header handler
return &tunVirtioReader{f: file}, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -268,7 +354,70 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r
}
// Read fills b with the next inbound packet from the tun device.
func (t *tun) Read(b []byte) (int, error) {
	if t.wgDevice == nil {
		// Fallback: direct read from file (shouldn't happen in normal operation)
		return t.ReadWriteCloser.Read(b)
	}

	// The wireguard device handles virtio headers internally; read a
	// single packet straight into the caller's buffer.
	sizes := []int{0}
	count, err := t.wgDevice.Read([][]byte{b}, sizes, 0)
	switch {
	case err != nil:
		return 0, err
	case count == 0:
		return 0, io.EOF
	default:
		return sizes[0], nil
	}
}
// BatchRead reads multiple packets at once for improved performance.
// bufs receives the packet payloads and sizes receives each packet's
// length; the return value is the number of packets read.
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
	if t.wgDevice == nil {
		// Fallback: single packet read
		size, err := t.ReadWriteCloser.Read(bufs[0])
		if err != nil {
			return 0, err
		}
		sizes[0] = size
		return 1, nil
	}

	return t.wgDevice.Read(bufs, sizes, 0)
}
// BatchSize returns the optimal number of packets to read/write in a
// batch: the wireguard device's preference when present, otherwise 1.
func (t *tun) BatchSize() int {
	if t.wgDevice == nil {
		return 1
	}
	return t.wgDevice.BatchSize()
}
// Write sends the single packet b out the tun device. When the
// wireguard device is present it is used (it handles the virtio
// header); otherwise control falls through to the raw-file path below.
func (t *tun) Write(b []byte) (int, error) {
	if t.wgDevice != nil {
		// Use wireguard device which handles virtio headers internally
		// Allocate buffer with space for virtio header
		// NOTE(review): this allocates on every call; a reusable
		// scratch buffer could avoid it if Write is single-goroutine —
		// confirm caller threading before changing.
		buf := make([]byte, virtioNetHdrLen+len(b))
		copy(buf[virtioNetHdrLen:], b)
		bufs := [][]byte{buf}
		n, err := t.wgDevice.Write(bufs, virtioNetHdrLen)
		if err != nil {
			return 0, err
		}
		// n counts packets accepted, not bytes; zero means nothing sent.
		if n == 0 {
			return 0, io.ErrShortWrite
		}
		// Report the payload length as the byte count written.
		return len(b), nil
	}
	// Fallback: direct write (shouldn't happen in normal operation)
	var nn int
	maximum := len(b)