ReadBatch

This commit is contained in:
JackDoan
2026-04-17 11:05:34 -05:00
parent 5241bf6d16
commit c05fa793a6
17 changed files with 235 additions and 56 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlay.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlay.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlay.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlay.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"sync"
"sync/atomic"
@@ -86,7 +85,7 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []udp.Conn
readers []io.ReadWriteCloser
readers []overlay.Queue
wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown.
@@ -184,7 +183,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
readers: make([]overlay.Queue, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
@@ -239,7 +238,7 @@ func (f *Interface) activate() error {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
var reader io.ReadWriteCloser = f.inside
var reader overlay.Queue = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
@@ -321,8 +320,7 @@ func (f *Interface) listenOut(i int) {
f.l.Infof("underlay reader %v is done", i)
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu)
func (f *Interface) listenIn(reader overlay.Queue, i int) {
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
@@ -330,7 +328,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
batch, err := reader.ReadBatch()
if err != nil {
if !f.closed.Load() {
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
@@ -339,7 +337,9 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
break
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
for _, pkt := range batch {
f.consumeInsidePacket(pkt, fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
}
f.l.Infof("overlay reader %v is done", i)

View File

@@ -7,12 +7,26 @@ import (
"github.com/slackhq/nebula/routing"
)
type Device interface {
// defaultBatchBufSize is the per-Queue scratch size for ReadBatch on backends
// that don't do TSO segmentation. 65535 is the largest value the 16-bit IPv4
// Total Length field can express, so it covers any single IPv4 packet.
// NOTE(review): a non-jumbogram IPv6 packet can be 40 + 65535 bytes because
// its Payload Length field excludes the fixed header; this presumably relies
// on the tun MTU never exceeding 65535 — confirm.
const defaultBatchBufSize = 65535
// Queue is a single read/write handle onto a tun device.
//
// ReadBatch yields one or more packets per call. The returned slices alias the
// queue's internal buffer: they remain valid only until the next ReadBatch,
// Read or Close on the same Queue, so callers must encrypt or copy every slice
// before calling into the Queue again.
//
// A Queue is not safe for concurrent use; dedicate one goroutine to each Queue.
type Queue interface {
	ReadBatch() ([][]byte, error)
	io.ReadWriteCloser
}
// Device is a tun device: it is itself a Queue and adds the device-level
// operations listed below.
type Device interface {
Queue
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
// SupportsMultiqueue reports whether NewMultiQueueReader can hand out
// additional queues; the backends in this package that return false also
// return an error from NewMultiQueueReader.
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error)
// NewMultiQueueReader returns an additional Queue onto the same device, or an
// error when the backend has no multiqueue support.
NewMultiQueueReader() (Queue, error)
}

View File

@@ -1,8 +1,7 @@
package test
package overlay
import (
"errors"
"io"
"net/netip"
"github.com/slackhq/nebula/routing"
@@ -30,6 +29,10 @@ func (NoopTun) Read([]byte) (int, error) {
return 0, nil
}
// ReadBatch never produces packets: it reports a nil batch and a nil error.
func (NoopTun) ReadBatch() ([][]byte, error) {
	var empty [][]byte
	return empty, nil
}
func (NoopTun) Write([]byte) (int, error) {
return 0, nil
}
@@ -38,7 +41,7 @@ func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always fails with "unsupported"; NoopTun hands out no
// additional queues.
func (NoopTun) NewMultiQueueReader() (Queue, error) {
	var none Queue
	return none, errors.New("unsupported")
}

View File

@@ -24,6 +24,21 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *tun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
@@ -99,6 +114,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: no additional queues are
// available on android.
func (t *tun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented for android")
	return nil, err
}

View File

@@ -34,6 +34,9 @@ type tun struct {
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
readBuf []byte
batchRet [1][]byte
}
type ifReq struct {
@@ -512,6 +515,18 @@ func (t *tun) Read(to []byte) (int, error) {
return n - 4, err
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *tun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
@@ -553,6 +568,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: no additional queues are
// available on darwin.
func (t *tun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented for darwin")
	return nil, err
}

View File

@@ -20,6 +20,21 @@ type disabledTun struct {
tx metrics.Counter
rx metrics.Counter
l *logrus.Logger
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
//
// NOTE(review): the scratch buffer lives on the disabledTun itself, so two
// goroutines calling ReadBatch on the same *disabledTun would overwrite each
// other's packet — confirm every reader gets its own Queue value.
func (t *disabledTun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
@@ -109,7 +124,7 @@ func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// disabledTunQueue is the Queue handed out by NewMultiQueueReader. It shares
// the underlying disabledTun (Read, Write and Close are promoted unchanged)
// but owns its ReadBatch scratch space, so each reader goroutine works on a
// private buffer instead of racing on the parent's readBuf/batchRet.
type disabledTunQueue struct {
	*disabledTun
	readBuf  []byte
	batchRet [1][]byte
}

// ReadBatch reads one packet from the shared tun into this queue's own buffer
// and returns it as a one-element batch. The slice is only valid until the
// next ReadBatch on this queue.
func (q *disabledTunQueue) ReadBatch() ([][]byte, error) {
	n, err := q.disabledTun.Read(q.readBuf)
	if err != nil {
		return nil, err
	}
	q.batchRet[0] = q.readBuf[:n]
	return q.batchRet[:], nil
}

// NewMultiQueueReader returns an additional Queue backed by the same
// disabledTun. Returning t itself would make every reader share one ReadBatch
// scratch buffer, letting concurrent readers overwrite each other's packets,
// so each call gets a wrapper with its own buffer.
// NOTE(review): assumes disabledTun.Read tolerates concurrent callers, as it
// already had to when t was returned for every queue — confirm.
func (t *disabledTun) NewMultiQueueReader() (Queue, error) {
	return &disabledTunQueue{
		disabledTun: t,
		readBuf:     make([]byte, defaultBatchBufSize),
	}, nil
}

View File

@@ -7,7 +7,6 @@ import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"net/netip"
"sync/atomic"
@@ -94,6 +93,21 @@ type tun struct {
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *tun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
func (t *tun) Read(to []byte) (int, error) {
@@ -454,7 +468,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: no additional queues are
// available on freebsd.
func (t *tun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented for freebsd")
	return nil, err
}

View File

@@ -26,6 +26,21 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *tun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
@@ -155,6 +170,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: no additional queues are
// available on ios.
func (t *tun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented for ios")
	return nil, err
}

View File

@@ -62,6 +62,7 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
fd: fd,
shutdownFd: r.shutdownFd,
vnetHdr: r.vnetHdr,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
@@ -72,7 +73,6 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
},
}
if r.vnetHdr {
out.readBuf = make([]byte, tunReadBufSize)
out.segBuf = make([]byte, tunSegBufSize)
out.writeIovs[0].Base = &zeroVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen)
@@ -95,6 +95,7 @@ func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
shutdownFd: shutdownFd,
lastOne: true,
vnetHdr: vnetHdr,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
@@ -105,7 +106,6 @@ func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
},
}
if vnetHdr {
out.readBuf = make([]byte, tunReadBufSize)
out.segBuf = make([]byte, tunSegBufSize)
out.writeIovs[0].Base = &zeroVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen)
@@ -181,11 +181,39 @@ func (r *tunFile) readRaw(buf []byte) (int, error) {
}
}
func (r *tunFile) Read(buf []byte) (int, error) {
// ReadBatch reads one superpacket from the tun and returns the resulting
// packets. Slices point into the tunFile's internal buffers and are only
// valid until the next ReadBatch / Read / Close on this Queue.
//
// Without vnetHdr the raw read is returned as a single-packet batch. With
// vnetHdr the leading virtio-net header is decoded and the remaining payload
// is handed to segmentInto, which fills r.pending. A payload that fails
// segmentation is dropped and another read is attempted; a read shorter than
// the header, or any readRaw error, is returned to the caller.
func (r *tunFile) ReadBatch() ([][]byte, error) {
	for {
		// Reset on every attempt rather than once per call: if segmentInto
		// appended some segments before failing, they must not be carried into
		// the batch built from the next read.
		r.pending = r.pending[:0]
		r.pendingIdx = 0

		n, err := r.readRaw(r.readBuf)
		if err != nil {
			return nil, err
		}
		if !r.vnetHdr {
			r.pending = append(r.pending, r.readBuf[:n])
			return r.pending, nil
		}
		if n < virtioNetHdrLen {
			return nil, fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
		}
		var hdr virtioNetHdr
		hdr.decode(r.readBuf[:virtioNetHdrLen])
		if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf); err != nil {
			// Drop and read again — a bad packet should not kill the reader.
			continue
		}
		return r.pending, nil
	}
}
// Read drains segments produced by the last ReadBatch one at a time; when the
// batch is exhausted it fetches a fresh one. Kept for io.Reader compatibility;
// batch-aware callers should use ReadBatch directly.
func (r *tunFile) Read(buf []byte) (int, error) {
for {
if r.pendingIdx < len(r.pending) {
seg := r.pending[r.pendingIdx]
@@ -195,22 +223,9 @@ func (r *tunFile) Read(buf []byte) (int, error) {
}
return copy(buf, seg), nil
}
r.pending = r.pending[:0]
r.pendingIdx = 0
n, err := r.readRaw(r.readBuf)
if err != nil {
if _, err := r.ReadBatch(); err != nil {
return 0, err
}
if n < virtioNetHdrLen {
return 0, fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
}
var hdr virtioNetHdr
hdr.decode(r.readBuf[:virtioNetHdrLen])
if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf); err != nil {
// Drop and read again — a bad packet should not kill the reader.
continue
}
}
}
@@ -540,7 +555,7 @@ func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (Queue, error) {
t.closeLock.Lock()
defer t.closeLock.Unlock()

View File

@@ -1,5 +1,5 @@
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay

View File

@@ -1,5 +1,5 @@
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay

View File

@@ -6,7 +6,6 @@ package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"regexp"
@@ -66,6 +65,21 @@ type tun struct {
l *logrus.Logger
f *os.File
fd int
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *tun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -394,7 +408,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: no additional queues are
// available on netbsd.
func (t *tun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented for netbsd")
	return nil, err
}

View File

@@ -6,7 +6,6 @@ package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"regexp"
@@ -59,6 +58,21 @@ type tun struct {
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *tun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -314,7 +328,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: no additional queues are
// available on openbsd.
func (t *tun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented for openbsd")
	return nil, err
}

View File

@@ -26,6 +26,21 @@ type TestTun struct {
closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *TestTun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
@@ -136,6 +151,6 @@ func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: TestTun hands out no
// additional queues.
func (t *TestTun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented")
	return nil, err
}

View File

@@ -6,7 +6,6 @@ package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
@@ -36,6 +35,21 @@ type winTun struct {
l *logrus.Logger
tun *wintun.NativeTun
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases t.readBuf and is overwritten by
// the next call.
func (t *winTun) ReadBatch() ([][]byte, error) {
	scratch := t.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		t.readBuf = scratch
	}

	size, readErr := t.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	t.batchRet[0] = scratch[:size]
	return t.batchRet[:], nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
@@ -241,7 +255,7 @@ func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader always returns an error: no additional queues are
// available on windows.
func (t *winTun) NewMultiQueueReader() (Queue, error) {
	err := fmt.Errorf("TODO: multiqueue not implemented for windows")
	return nil, err
}

View File

@@ -34,6 +34,21 @@ type UserDevice struct {
inboundReader *io.PipeReader
inboundWriter *io.PipeWriter
readBuf []byte
batchRet [1][]byte
}
// ReadBatch reads a single packet through Read and returns it as a
// one-element batch. The scratch buffer is allocated on first use and reused
// afterwards, so the returned slice aliases d.readBuf and is overwritten by
// the next call.
//
// NOTE(review): the scratch buffer lives on the UserDevice itself, so two
// goroutines calling ReadBatch on the same *UserDevice would overwrite each
// other's packet — confirm every reader gets its own Queue value.
func (d *UserDevice) ReadBatch() ([][]byte, error) {
	scratch := d.readBuf
	if scratch == nil {
		scratch = make([]byte, defaultBatchBufSize)
		d.readBuf = scratch
	}

	size, readErr := d.Read(scratch)
	if readErr != nil {
		return nil, readErr
	}

	d.batchRet[0] = scratch[:size]
	return d.batchRet[:], nil
}
func (d *UserDevice) Activate() error {
@@ -50,7 +65,7 @@ func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// userDeviceQueue is the Queue handed out by NewMultiQueueReader. It shares
// the underlying UserDevice (Read, Write and Close are promoted unchanged) but
// owns its ReadBatch scratch space, so each reader goroutine works on a
// private buffer instead of racing on the parent's readBuf/batchRet.
type userDeviceQueue struct {
	*UserDevice
	readBuf  []byte
	batchRet [1][]byte
}

// ReadBatch reads one packet from the shared device into this queue's own
// buffer and returns it as a one-element batch. The slice is only valid until
// the next ReadBatch on this queue.
func (q *userDeviceQueue) ReadBatch() ([][]byte, error) {
	n, err := q.UserDevice.Read(q.readBuf)
	if err != nil {
		return nil, err
	}
	q.batchRet[0] = q.readBuf[:n]
	return q.batchRet[:], nil
}

// NewMultiQueueReader returns an additional Queue backed by the same
// UserDevice. Returning d itself would make every reader share one ReadBatch
// scratch buffer, letting concurrent readers overwrite each other's packets,
// so each call gets a wrapper with its own buffer.
// NOTE(review): assumes UserDevice.Read tolerates concurrent callers, as it
// already had to when d was returned for every queue — confirm.
func (d *UserDevice) NewMultiQueueReader() (Queue, error) {
	return &userDeviceQueue{
		UserDevice: d,
		readBuf:    make([]byte, defaultBatchBufSize),
	}, nil
}