diff --git a/interface.go b/interface.go
index 082906d..7e8dd4c 100644
--- a/interface.go
+++ b/interface.go
@@ -276,9 +276,26 @@ func (f *Interface) listenOut(i int) {
 	})
 }
 
+// BatchReader is an interface for devices that support reading multiple packets at once
+type BatchReader interface {
+	BatchRead(bufs [][]byte, sizes []int) (int, error)
+	BatchSize() int
+}
+
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	runtime.LockOSThread()
 
+	// Check if reader supports batching
+	batchReader, supportsBatching := reader.(BatchReader)
+
+	if supportsBatching {
+		f.listenInBatch(reader, batchReader, i)
+	} else {
+		f.listenInSingle(reader, i)
+	}
+}
+
+func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) {
 	packet := make([]byte, mtu)
 	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}
@@ -302,6 +319,42 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	}
 }
 
+func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) {
+	batchSize := batchReader.BatchSize()
+
+	// Allocate buffers for batch reading
+	bufs := make([][]byte, batchSize)
+	for idx := range bufs {
+		bufs[idx] = make([]byte, mtu)
+	}
+	sizes := make([]int, batchSize)
+
+	// Per-packet state (reused across batches)
+	out := make([]byte, mtu)
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+
+	for {
+		n, err := batchReader.BatchRead(bufs, sizes)
+		if err != nil {
+			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
+				return
+			}
+
+			f.l.WithError(err).Error("Error while batch reading outbound packets")
+			// This only seems to happen when something fatal happens to the fd, so exit.
+			os.Exit(2)
+		}
+
+		// Process each packet in the batch
+		for j := 0; j < n; j++ {
+			f.consumeInsidePacket(bufs[j][:sizes[j]], fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		}
+	}
+}
+
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index 9df4581..882a7e7 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -66,6 +66,78 @@ type ifreqQLEN struct {
 	pad [8]byte
 }
 
+const (
+	virtioNetHdrLen = 10 // Size of virtio_net_hdr structure
+)
+
+// tunVirtioReader wraps a file descriptor that has IFF_VNET_HDR enabled
+// and strips the virtio header on reads, adds it on writes
+type tunVirtioReader struct {
+	f   *os.File
+	buf [virtioNetHdrLen + 65535]byte // Space for header + max packet
+}
+
+func (r *tunVirtioReader) Read(b []byte) (int, error) {
+	// Read into our buffer which has space for the virtio header
+	n, err := r.f.Read(r.buf[:])
+	if err != nil {
+		return 0, err
+	}
+
+	// Strip the virtio header (first 10 bytes)
+	if n < virtioNetHdrLen {
+		return 0, fmt.Errorf("packet too short: %d bytes", n)
+	}
+
+	// Copy payload (after header) to destination
+	copy(b, r.buf[virtioNetHdrLen:n])
+	return n - virtioNetHdrLen, nil
+}
+
+func (r *tunVirtioReader) Write(b []byte) (int, error) {
+	// Zero out the virtio header (no offload from userspace write)
+	for i := 0; i < virtioNetHdrLen; i++ {
+		r.buf[i] = 0
+	}
+
+	// Copy packet data after header
+	copy(r.buf[virtioNetHdrLen:], b)
+
+	// Write with header prepended
+	n, err := r.f.Write(r.buf[:virtioNetHdrLen+len(b)])
+	if err != nil {
+		return 0, err
+	}
+
+	// Return payload size (excluding header)
+	return n - virtioNetHdrLen, nil
+}
+
+func (r *tunVirtioReader) Close() error {
+	return r.f.Close()
+}
+
+// BatchRead reads multiple packets at once for performance
+// This is not used for multiqueue readers as they use direct file I/O
+// Returns number of packets read
+func (r *tunVirtioReader) BatchRead(bufs [][]byte, sizes []int) (int, error) {
+	// For multiqueue file descriptors, we don't have the wireguard Device interface
+	// Fall back to single packet reads
+	// TODO: Could implement proper batching with unix.Recvmmsg
+	n, err := r.Read(bufs[0])
+	if err != nil {
+		return 0, err
+	}
+	sizes[0] = n
+	return 1, nil
+}
+
+// BatchSize returns the batch size for multiqueue readers
+func (r *tunVirtioReader) BatchSize() int {
+	// Multiqueue readers use single packet mode for now
+	return 1
+}
+
 func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
 	if err != nil {
@@ -106,7 +178,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	}
 
 	var req ifReq
-	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
+	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
 	if multiqueue {
 		req.Flags |= unix.IFF_MULTI_QUEUE
 	}
@@ -122,6 +194,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 		return nil, err
 	}
 
+	// Enable TCP and UDP offload (TSO/GRO) for performance
+	// This allows the kernel to handle segmentation/coalescing
+	const (
+		tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+		tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
+	)
+	offloads := tunTCPOffloads | tunUDPOffloads
+	if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
+		// Log warning but don't fail - offload is optional
+		l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
+	}
+
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
 
 	// Create wireguard device from file descriptor
@@ -251,16 +335,18 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	}
 
 	var req ifReq
-	// MUST match the flags used in newTun
-	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
+	// MUST match the flags used in newTun - includes IFF_VNET_HDR
+	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
 	copy(req.Name[:], t.Device)
 	if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
+		unix.Close(fd)
 		return nil, err
 	}
 
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
 
-	return file, nil
+	// Wrap in virtio header handler
+	return &tunVirtioReader{f: file}, nil
 }
 
 func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -268,7 +354,70 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	return r
 }
 
+func (t *tun) Read(b []byte) (int, error) {
+	if t.wgDevice != nil {
+		// Use wireguard device which handles virtio headers internally
+		bufs := [][]byte{b}
+		sizes := make([]int, 1)
+		n, err := t.wgDevice.Read(bufs, sizes, 0)
+		if err != nil {
+			return 0, err
+		}
+		if n == 0 {
+			return 0, io.EOF
+		}
+		return sizes[0], nil
+	}
+
+	// Fallback: direct read from file (shouldn't happen in normal operation)
+	return t.ReadWriteCloser.Read(b)
+}
+
+// BatchRead reads multiple packets at once for improved performance
+// bufs: slice of buffers to read into
+// sizes: slice that will be filled with packet sizes
+// Returns number of packets read
+func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
+	if t.wgDevice != nil {
+		return t.wgDevice.Read(bufs, sizes, 0)
+	}
+
+	// Fallback: single packet read
+	n, err := t.ReadWriteCloser.Read(bufs[0])
+	if err != nil {
+		return 0, err
+	}
+	sizes[0] = n
+	return 1, nil
+}
+
+// BatchSize returns the optimal number of packets to read/write in a batch
+func (t *tun) BatchSize() int {
+	if t.wgDevice != nil {
+		return t.wgDevice.BatchSize()
+	}
+	return 1
+}
+
 func (t *tun) Write(b []byte) (int, error) {
+	if t.wgDevice != nil {
+		// Use wireguard device which handles virtio headers internally
+		// Allocate buffer with space for virtio header
+		buf := make([]byte, virtioNetHdrLen+len(b))
+		copy(buf[virtioNetHdrLen:], b)
+
+		bufs := [][]byte{buf}
+		n, err := t.wgDevice.Write(bufs, virtioNetHdrLen)
+		if err != nil {
+			return 0, err
+		}
+		if n == 0 {
+			return 0, io.ErrShortWrite
+		}
+		return len(b), nil
+	}
+
+	// Fallback: direct write (shouldn't happen in normal operation)
 	var nn int
 	maximum := len(b)