fix BatchRead interface & make batch size configurable

Jay Wren
2025-10-23 14:37:26 -04:00
parent 0827a6f1c5
commit 8281b1699f
8 changed files with 29 additions and 60 deletions

View File

@@ -47,6 +47,7 @@ type InterfaceConfig struct {
 	reQueryWait           time.Duration
 	ConntrackCacheTimeout time.Duration
+	batchSize             int
 	l                     *logrus.Logger
 }

@@ -84,6 +85,7 @@ type Interface struct {
 	version               string
 	conntrackCacheTimeout time.Duration
+	batchSize             int

 	writers []udp.Conn
 	readers []io.ReadWriteCloser

@@ -112,7 +114,7 @@ type EncWriter interface {
 // BatchReader is an interface for readers that support vectorized packet reading
 type BatchReader interface {
-	BatchRead() ([][]byte, []int, error)
+	BatchRead(buffers [][]byte, sizes []int) (int, error)
 }

 // BatchWriter is an interface for writers that support vectorized packet writing
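The new signature follows the caller-allocates pattern: the caller hands in reusable buffers and a sizes slice, and the reader reports how many entries it filled. As a minimal sketch of what an implementation and its obligations look like, here is a hypothetical in-memory loopbackReader; only the interface signature comes from this commit, the rest is illustrative.

// Hypothetical BatchReader implementation that drains an in-memory queue.
type loopbackReader struct {
	queue [][]byte // packets waiting to be "read"
}

func (r *loopbackReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
	n := 0
	for n < len(buffers) && n < len(sizes) && len(r.queue) > 0 {
		sizes[n] = copy(buffers[n], r.queue[0]) // fill the caller-owned buffer, record the packet length
		r.queue = r.queue[1:]
		n++
	}
	return n, nil // only the first n entries of buffers/sizes are meaningful
}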
@@ -196,6 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		relayManager:          c.relayManager,
 		connectionManager:     c.connectionManager,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
+		batchSize:             c.batchSize,

 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
@@ -323,21 +326,28 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
// listenInBatch handles vectorized packet reading for improved performance // listenInBatch handles vectorized packet reading for improved performance
func (f *Interface) listenInBatch(reader BatchReader, i int) error { func (f *Interface) listenInBatch(reader BatchReader, i int) error {
// Allocate per-packet state // Allocate per-packet state and buffers for batch reading
fwPackets := make([]*firewall.Packet, 64) // Match batch size batchSize := f.batchSize
outBuffers := make([][]byte, 64) if batchSize <= 0 {
nbBuffers := make([][]byte, 64) batchSize = 64 // Fallback to default if not configured
}
fwPackets := make([]*firewall.Packet, batchSize)
outBuffers := make([][]byte, batchSize)
nbBuffers := make([][]byte, batchSize)
packets := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for j := 0; j < 64; j++ { for j := 0; j < batchSize; j++ {
fwPackets[j] = &firewall.Packet{} fwPackets[j] = &firewall.Packet{}
outBuffers[j] = make([]byte, mtu) outBuffers[j] = make([]byte, mtu)
nbBuffers[j] = make([]byte, 12, 12) nbBuffers[j] = make([]byte, 12)
packets[j] = make([]byte, mtu)
} }
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
packets, sizes, err := reader.BatchRead() n, err := reader.BatchRead(packets, sizes)
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() { if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return nil return nil
@@ -348,8 +358,8 @@ func (f *Interface) listenInBatch(reader BatchReader, i int) error {
// Process each packet in the batch // Process each packet in the batch
cache := conntrackCache.Get(f.l) cache := conntrackCache.Get(f.l)
for idx := 0; idx < len(packets); idx++ { for idx := 0; idx < n; idx++ {
if idx < len(sizes) && sizes[idx] > 0 { if sizes[idx] > 0 {
// Use modulo to reuse fw packet state if batch is larger than our pre-allocated state // Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
stateIdx := idx % len(fwPackets) stateIdx := idx % len(fwPackets)
f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache) f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)
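This loop also pins down the consumption rule: only the first n packets returned by BatchRead are valid, and each packet's payload is the sizes[idx]-byte prefix of its buffer. A hedged unit-test sketch of that rule, reusing the hypothetical loopbackReader from the earlier sketch (not part of the commit):

import "testing"

func TestBatchReadContract(t *testing.T) {
	r := &loopbackReader{queue: [][]byte{[]byte("a"), []byte("bc")}}
	packets := make([][]byte, 4)
	sizes := make([]int, 4)
	for i := range packets {
		packets[i] = make([]byte, 1500) // caller-owned, reused across calls
	}
	n, err := r.BatchRead(packets, sizes)
	if err != nil || n != 2 {
		t.Fatalf("want 2 packets, got n=%d err=%v", n, err)
	}
	if sizes[0] != 1 || sizes[1] != 2 {
		t.Fatalf("unexpected sizes: %v", sizes[:n])
	}
}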

View File

@@ -242,6 +242,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger
 		relayManager:          NewRelayManager(ctx, l, hostMap, c),
 		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
+		batchSize:             c.GetInt("tun.batch_size", 64),
 		l:                     l,
 	}
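On the user side, the batch size now comes from the tun section of the config, defaulting to 64 when the key is absent (and listenInBatch falls back to 64 for non-positive values). A hedged example of the corresponding YAML, assuming the usual nested layout for dotted keys; 128 is an arbitrary illustrative value:

tun:
  # packets requested per vectorized read; 64 when unset
  batch_size: 128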

View File

@@ -258,7 +258,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil

View File

@@ -214,7 +214,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil

View File

@@ -323,7 +323,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64, // Default batch size
 		offset:    0,
 		l:         t.l,
 	}, nil

View File

@@ -196,7 +196,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil

View File

@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"io"
 	"net/netip"
-	"sync"
 	"sync/atomic"

 	"github.com/gaissmai/bart"
@@ -37,7 +36,7 @@ type wgTun struct {
 // BatchReader interface for readers that support vectorized I/O
 type BatchReader interface {
-	BatchRead() ([][]byte, []int, error)
+	BatchRead(buffers [][]byte, sizes []int) (int, error)
 }

 // BatchWriter interface for writers that support vectorized I/O
@@ -47,24 +46,12 @@ type BatchWriter interface {
 // wgTunReader wraps a single TUN queue for multi-queue support
 type wgTunReader struct {
 	parent    *wgTun
 	tunDevice wgtun.Device
-	buffers   [][]byte
-	sizes     []int
-	offset    int
-	batchSize int
-	l         *logrus.Logger
+	offset    int
+	l         *logrus.Logger
 }

-var (
-	bufferPool = sync.Pool{
-		New: func() interface{} {
-			buf := make([]byte, 9001) // MTU size
-			return &buf
-		},
-	}
-)
-
 func (t *wgTun) Networks() []netip.Prefix {
 	return t.vpnNetworks
 }
@@ -210,23 +197,9 @@ func (t *wgTun) reload(c *config.C, initial bool) error {
 }

 // BatchRead reads multiple packets from the TUN device using vectorized I/O
-func (r *wgTunReader) BatchRead() ([][]byte, []int, error) {
-	// Reuse buffers from pool
-	if len(r.buffers) == 0 {
-		r.buffers = make([][]byte, r.batchSize)
-		r.sizes = make([]int, r.batchSize)
-		for i := 0; i < r.batchSize; i++ {
-			buf := bufferPool.Get().(*[]byte)
-			r.buffers[i] = (*buf)[:cap(*buf)]
-		}
-	}
-
-	n, err := r.tunDevice.Read(r.buffers, r.sizes, r.offset)
-	if err != nil {
-		return nil, nil, err
-	}
-	return r.buffers[:n], r.sizes[:n], nil
+// The caller provides buffers and sizes slices, and this function returns the number of packets read.
+func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
+	return r.tunDevice.Read(buffers, sizes, r.offset)
 }

 // Read implements io.Reader for wgTunReader (single packet for compatibility)
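Dropping the pool works because buffer ownership moved to the caller: the reader no longer holds packet memory and simply forwards to the device's vectorized read. A minimal sketch of that shape, under the assumption that the device exposes a Read(bufs, sizes, offset) method as the call above suggests; vectorizedDevice and passthroughReader are illustrative names, not from this commit.

// Assumed signature of the underlying vectorized read, inferred from
// r.tunDevice.Read(buffers, sizes, r.offset) in this diff.
type vectorizedDevice interface {
	Read(bufs [][]byte, sizes []int, offset int) (int, error)
}

// With that assumption, a reader is just a passthrough; the caller's only
// obligation is to hand in buffers large enough for offset + MTU bytes
// (offset is 0 in the NewMultiQueueReader constructors in this commit).
type passthroughReader struct {
	dev    vectorizedDevice
	offset int
}

func (r *passthroughReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
	return r.dev.Read(buffers, sizes, r.offset)
}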
@@ -262,16 +235,6 @@ func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) {
 }

 func (r *wgTunReader) Close() error {
-	// Return buffers to pool
-	for i := range r.buffers {
-		if r.buffers[i] != nil {
-			bufferPool.Put(&r.buffers[i])
-			r.buffers[i] = nil
-		}
-	}
-	r.buffers = nil
-	r.sizes = nil
-
 	if r.tunDevice != nil {
 		return r.tunDevice.Close()
 	}

View File

@@ -175,7 +175,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil