diff --git a/interface.go b/interface.go index 82634df..8d17573 100644 --- a/interface.go +++ b/interface.go @@ -47,6 +47,7 @@ type InterfaceConfig struct { reQueryWait time.Duration ConntrackCacheTimeout time.Duration + batchSize int l *logrus.Logger } @@ -84,6 +85,7 @@ type Interface struct { version string conntrackCacheTimeout time.Duration + batchSize int writers []udp.Conn readers []io.ReadWriteCloser @@ -112,7 +114,7 @@ type EncWriter interface { // BatchReader is an interface for readers that support vectorized packet reading type BatchReader interface { - BatchRead() ([][]byte, []int, error) + BatchRead(buffers [][]byte, sizes []int) (int, error) } // BatchWriter is an interface for writers that support vectorized packet writing @@ -196,6 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { relayManager: c.relayManager, connectionManager: c.connectionManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, + batchSize: c.batchSize, metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, @@ -323,21 +326,28 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { // listenInBatch handles vectorized packet reading for improved performance func (f *Interface) listenInBatch(reader BatchReader, i int) error { - // Allocate per-packet state - fwPackets := make([]*firewall.Packet, 64) // Match batch size - outBuffers := make([][]byte, 64) - nbBuffers := make([][]byte, 64) + // Allocate per-packet state and buffers for batch reading + batchSize := f.batchSize + if batchSize <= 0 { + batchSize = 64 // Fallback to default if not configured + } + fwPackets := make([]*firewall.Packet, batchSize) + outBuffers := make([][]byte, batchSize) + nbBuffers := make([][]byte, batchSize) + packets := make([][]byte, batchSize) + sizes := make([]int, batchSize) - for j := 0; j < 64; j++ { + for j := 0; j < batchSize; j++ { fwPackets[j] = &firewall.Packet{} outBuffers[j] = make([]byte, mtu) - nbBuffers[j] = make([]byte, 12, 12) + nbBuffers[j] = make([]byte, 12) + packets[j] = make([]byte, mtu) } conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { - packets, sizes, err := reader.BatchRead() + n, err := reader.BatchRead(packets, sizes) if err != nil { if errors.Is(err, os.ErrClosed) && f.closed.Load() { return nil @@ -348,8 +358,8 @@ func (f *Interface) listenInBatch(reader BatchReader, i int) error { // Process each packet in the batch cache := conntrackCache.Get(f.l) - for idx := 0; idx < len(packets); idx++ { - if idx < len(sizes) && sizes[idx] > 0 { + for idx := 0; idx < n; idx++ { + if sizes[idx] > 0 { // Use modulo to reuse fw packet state if batch is larger than our pre-allocated state stateIdx := idx % len(fwPackets) f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache) diff --git a/main.go b/main.go index eb296fb..f5462bb 100644 --- a/main.go +++ b/main.go @@ -242,6 +242,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg relayManager: NewRelayManager(ctx, l, hostMap, c), punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, + batchSize: c.GetInt("tun.batch_size", 64), l: l, } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 47ad7b4..d29d0ff 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -258,7 +258,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) { return &wgTunReader{ parent: t, tunDevice: t.tunDevice, - batchSize: 64, offset: 0, l: t.l, }, nil diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 913c6f8..6b40af4 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -214,7 +214,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) { return &wgTunReader{ parent: t, tunDevice: t.tunDevice, - batchSize: 64, offset: 0, l: t.l, }, nil diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index c19a74b..f50bb9d 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -323,7 +323,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) { return &wgTunReader{ parent: t, tunDevice: t.tunDevice, - batchSize: 64, // Default batch size offset: 0, l: t.l, }, nil diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9227b8a..93cda97 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -196,7 +196,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) { return &wgTunReader{ parent: t, tunDevice: t.tunDevice, - batchSize: 64, offset: 0, l: t.l, }, nil diff --git a/overlay/tun_wg.go b/overlay/tun_wg.go index 06017e3..459feb9 100644 --- a/overlay/tun_wg.go +++ b/overlay/tun_wg.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/netip" - "sync" "sync/atomic" "github.com/gaissmai/bart" @@ -37,7 +36,7 @@ type wgTun struct { // BatchReader interface for readers that support vectorized I/O type BatchReader interface { - BatchRead() ([][]byte, []int, error) + BatchRead(buffers [][]byte, sizes []int) (int, error) } // BatchWriter interface for writers that support vectorized I/O @@ -47,24 +46,12 @@ type BatchWriter interface { // wgTunReader wraps a single TUN queue for multi-queue support type wgTunReader struct { - parent *wgTun - tunDevice wgtun.Device - buffers [][]byte - sizes []int - offset int - batchSize int - l *logrus.Logger + parent *wgTun + tunDevice wgtun.Device + offset int + l *logrus.Logger } -var ( - bufferPool = sync.Pool{ - New: func() interface{} { - buf := make([]byte, 9001) // MTU size - return &buf - }, - } -) - func (t *wgTun) Networks() []netip.Prefix { return t.vpnNetworks } @@ -210,23 +197,9 @@ func (t *wgTun) reload(c *config.C, initial bool) error { } // BatchRead reads multiple packets from the TUN device using vectorized I/O -func (r *wgTunReader) BatchRead() ([][]byte, []int, error) { - // Reuse buffers from pool - if len(r.buffers) == 0 { - r.buffers = make([][]byte, r.batchSize) - r.sizes = make([]int, r.batchSize) - for i := 0; i < r.batchSize; i++ { - buf := bufferPool.Get().(*[]byte) - r.buffers[i] = (*buf)[:cap(*buf)] - } - } - - n, err := r.tunDevice.Read(r.buffers, r.sizes, r.offset) - if err != nil { - return nil, nil, err - } - - return r.buffers[:n], r.sizes[:n], nil +// The caller provides buffers and sizes slices, and this function returns the number of packets read. +func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) { + return r.tunDevice.Read(buffers, sizes, r.offset) } // Read implements io.Reader for wgTunReader (single packet for compatibility) @@ -262,16 +235,6 @@ func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) { } func (r *wgTunReader) Close() error { - // Return buffers to pool - for i := range r.buffers { - if r.buffers[i] != nil { - bufferPool.Put(&r.buffers[i]) - r.buffers[i] = nil - } - } - r.buffers = nil - r.sizes = nil - if r.tunDevice != nil { return r.tunDevice.Close() } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 130c856..cc77fb5 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -175,7 +175,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) { return &wgTunReader{ parent: t, tunDevice: t.tunDevice, - batchSize: 64, offset: 0, l: t.l, }, nil