mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
more nonblocking
This commit is contained in:
@@ -70,72 +70,55 @@ const (
|
||||
virtioNetHdrLen = 10 // Size of virtio_net_hdr structure
|
||||
)
|
||||
|
||||
// tunVirtioReader wraps a file descriptor that has IFF_VNET_HDR enabled
|
||||
// and strips the virtio header on reads, adds it on writes
|
||||
type tunVirtioReader struct {
|
||||
f *os.File
|
||||
buf [virtioNetHdrLen + 65535]byte // Space for header + max packet
|
||||
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
|
||||
// This allows multiqueue readers to use the same wireguard Device batching as the main device
|
||||
type wgDeviceWrapper struct {
|
||||
dev wgtun.Device
|
||||
buf []byte // Reusable buffer for single packet reads
|
||||
}
|
||||
|
||||
func (r *tunVirtioReader) Read(b []byte) (int, error) {
|
||||
// Read into our buffer which has space for the virtio header
|
||||
n, err := r.f.Read(r.buf[:])
|
||||
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
|
||||
// Use wireguard Device's batch API for single packet
|
||||
bufs := [][]byte{b}
|
||||
sizes := make([]int, 1)
|
||||
n, err := w.dev.Read(bufs, sizes, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Strip the virtio header (first 10 bytes)
|
||||
if n < virtioNetHdrLen {
|
||||
return 0, fmt.Errorf("packet too short: %d bytes", n)
|
||||
if n == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return sizes[0], nil
|
||||
}
|
||||
|
||||
// Copy payload (after header) to destination
|
||||
copy(b, r.buf[virtioNetHdrLen:n])
|
||||
return n - virtioNetHdrLen, nil
|
||||
}
|
||||
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
|
||||
// Allocate buffer with space for virtio header
|
||||
buf := make([]byte, virtioNetHdrLen+len(b))
|
||||
copy(buf[virtioNetHdrLen:], b)
|
||||
|
||||
func (r *tunVirtioReader) Write(b []byte) (int, error) {
|
||||
// Zero out the virtio header (no offload from userspace write)
|
||||
for i := 0; i < virtioNetHdrLen; i++ {
|
||||
r.buf[i] = 0
|
||||
}
|
||||
|
||||
// Copy packet data after header
|
||||
copy(r.buf[virtioNetHdrLen:], b)
|
||||
|
||||
// Write with header prepended
|
||||
n, err := r.f.Write(r.buf[:virtioNetHdrLen+len(b)])
|
||||
bufs := [][]byte{buf}
|
||||
n, err := w.dev.Write(bufs, virtioNetHdrLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Return payload size (excluding header)
|
||||
return n - virtioNetHdrLen, nil
|
||||
if n == 0 {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (r *tunVirtioReader) Close() error {
|
||||
return r.f.Close()
|
||||
func (w *wgDeviceWrapper) Close() error {
|
||||
return w.dev.Close()
|
||||
}
|
||||
|
||||
// BatchRead reads multiple packets at once for performance
|
||||
// This is not used for multiqueue readers as they use direct file I/O
|
||||
// Returns number of packets read
|
||||
func (r *tunVirtioReader) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
// For multiqueue file descriptors, we don't have the wireguard Device interface
|
||||
// Fall back to single packet reads
|
||||
// TODO: Could implement proper batching with unix.Recvmmsg
|
||||
n, err := r.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
// BatchRead implements batching for multiqueue readers
|
||||
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
return w.dev.Read(bufs, sizes, 0)
|
||||
}
|
||||
|
||||
// BatchSize returns the batch size for multiqueue readers
|
||||
func (r *tunVirtioReader) BatchSize() int {
|
||||
// Multiqueue readers use single packet mode for now
|
||||
return 1
|
||||
// BatchSize returns the optimal batch size
|
||||
func (w *wgDeviceWrapper) BatchSize() int {
|
||||
return w.dev.BatchSize()
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
@@ -343,10 +326,29 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set nonblocking mode - CRITICAL for proper netpoller integration
|
||||
if err = unix.SetNonblock(fd, true); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get MTU from main device
|
||||
mtu := t.MaxMTU
|
||||
if mtu == 0 {
|
||||
mtu = DefaultMTU
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
// Wrap in virtio header handler
|
||||
return &tunVirtioReader{f: file}, nil
|
||||
// Create wireguard Device from the file descriptor (just like the main device)
|
||||
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
|
||||
}
|
||||
|
||||
// Return a wrapper that uses the wireguard Device for all I/O
|
||||
return &wgDeviceWrapper{dev: wgDev}, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
|
||||
Reference in New Issue
Block a user