ReadBatch

This commit is contained in:
JackDoan
2026-04-17 11:05:34 -05:00
parent 5241bf6d16
commit c05fa793a6
17 changed files with 235 additions and 56 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlay.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlay.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlay.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlay.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -86,7 +85,7 @@ type Interface struct {
conntrackCacheTimeout time.Duration conntrackCacheTimeout time.Duration
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []overlay.Queue
wg sync.WaitGroup wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown. // fatalErr holds the first unexpected reader error that caused shutdown.
@@ -184,7 +183,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines, routines: c.routines,
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]overlay.Queue, c.routines),
myVpnNetworks: cs.myVpnNetworks, myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable, myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs, myVpnAddrs: cs.myVpnAddrs,
@@ -239,7 +238,7 @@ func (f *Interface) activate() error {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
var reader io.ReadWriteCloser = f.inside var reader overlay.Queue = f.inside
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
@@ -321,8 +320,7 @@ func (f *Interface) listenOut(i int) {
f.l.Infof("underlay reader %v is done", i) f.l.Infof("underlay reader %v is done", i)
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader overlay.Queue, i int) {
packet := make([]byte, mtu)
out := make([]byte, mtu) out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
@@ -330,7 +328,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for { for {
n, err := reader.Read(packet) batch, err := reader.ReadBatch()
if err != nil { if err != nil {
if !f.closed.Load() { if !f.closed.Load() {
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
@@ -339,7 +337,9 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
break break
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) for _, pkt := range batch {
f.consumeInsidePacket(pkt, fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
} }
f.l.Infof("overlay reader %v is done", i) f.l.Infof("overlay reader %v is done", i)

View File

@@ -7,12 +7,26 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
type Device interface { // defaultBatchBufSize is the per-Queue scratch size for ReadBatch on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
// Queue is a readable/writable tun queue. ReadBatch returns one or more
// packets; the returned slices are borrowed from the queue's internal buffer
// and are only valid until the next ReadBatch / Read / Close on this Queue.
// Callers must encrypt or copy each slice before the next call. Not safe for
// concurrent use — one goroutine per Queue.
type Queue interface {
io.ReadWriteCloser io.ReadWriteCloser
ReadBatch() ([][]byte, error)
}
type Device interface {
Queue
Activate() error Activate() error
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (Queue, error)
} }

View File

@@ -1,8 +1,7 @@
package test package overlay
import ( import (
"errors" "errors"
"io"
"net/netip" "net/netip"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
@@ -30,6 +29,10 @@ func (NoopTun) Read([]byte) (int, error) {
return 0, nil return 0, nil
} }
func (NoopTun) ReadBatch() ([][]byte, error) {
return nil, nil
}
func (NoopTun) Write([]byte) (int, error) { func (NoopTun) Write([]byte) (int, error) {
return 0, nil return 0, nil
} }
@@ -38,7 +41,7 @@ func (NoopTun) SupportsMultiqueue() bool {
return false return false
} }
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) NewMultiQueueReader() (Queue, error) {
return nil, errors.New("unsupported") return nil, errors.New("unsupported")
} }

View File

@@ -24,6 +24,21 @@ type tun struct {
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
readBuf []byte
batchRet [1][]byte
}
func (t *tun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
@@ -99,6 +114,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android") return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
} }

View File

@@ -34,6 +34,9 @@ type tun struct {
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte out []byte
readBuf []byte
batchRet [1][]byte
} }
type ifReq struct { type ifReq struct {
@@ -512,6 +515,18 @@ func (t *tun) Read(to []byte) (int, error) {
return n - 4, err return n - 4, err
} }
func (t *tun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
}
// Write is only valid for single threaded use // Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) { func (t *tun) Write(from []byte) (int, error) {
buf := t.out buf := t.out
@@ -553,6 +568,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }

View File

@@ -20,6 +20,21 @@ type disabledTun struct {
tx metrics.Counter tx metrics.Counter
rx metrics.Counter rx metrics.Counter
l *logrus.Logger l *logrus.Logger
readBuf []byte
batchRet [1][]byte
}
func (t *disabledTun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
@@ -109,7 +124,7 @@ func (t *disabledTun) SupportsMultiqueue() bool {
return true return true
} }
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *disabledTun) NewMultiQueueReader() (Queue, error) {
return t, nil return t, nil
} }

View File

@@ -7,7 +7,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"io/fs" "io/fs"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
@@ -94,6 +93,21 @@ type tun struct {
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
devFd int devFd int
readBuf []byte
batchRet [1][]byte
}
func (t *tun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) Read(to []byte) (int, error) {
@@ -454,7 +468,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
} }

View File

@@ -26,6 +26,21 @@ type tun struct {
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
readBuf []byte
batchRet [1][]byte
}
func (t *tun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
@@ -155,6 +170,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
} }

View File

@@ -62,6 +62,7 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
fd: fd, fd: fd,
shutdownFd: r.shutdownFd, shutdownFd: r.shutdownFd,
vnetHdr: r.vnetHdr, vnetHdr: r.vnetHdr,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{ readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN}, {Fd: int32(r.shutdownFd), Events: unix.POLLIN},
@@ -72,7 +73,6 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
}, },
} }
if r.vnetHdr { if r.vnetHdr {
out.readBuf = make([]byte, tunReadBufSize)
out.segBuf = make([]byte, tunSegBufSize) out.segBuf = make([]byte, tunSegBufSize)
out.writeIovs[0].Base = &zeroVnetHdr[0] out.writeIovs[0].Base = &zeroVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen) out.writeIovs[0].SetLen(virtioNetHdrLen)
@@ -95,6 +95,7 @@ func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
shutdownFd: shutdownFd, shutdownFd: shutdownFd,
lastOne: true, lastOne: true,
vnetHdr: vnetHdr, vnetHdr: vnetHdr,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{ readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN}, {Fd: int32(shutdownFd), Events: unix.POLLIN},
@@ -105,7 +106,6 @@ func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
}, },
} }
if vnetHdr { if vnetHdr {
out.readBuf = make([]byte, tunReadBufSize)
out.segBuf = make([]byte, tunSegBufSize) out.segBuf = make([]byte, tunSegBufSize)
out.writeIovs[0].Base = &zeroVnetHdr[0] out.writeIovs[0].Base = &zeroVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen) out.writeIovs[0].SetLen(virtioNetHdrLen)
@@ -181,11 +181,39 @@ func (r *tunFile) readRaw(buf []byte) (int, error) {
} }
} }
func (r *tunFile) Read(buf []byte) (int, error) { // ReadBatch reads one superpacket from the tun and returns the resulting
if !r.vnetHdr { // packets. Slices point into the tunFile's internal buffers and are only
return r.readRaw(buf) // valid until the next ReadBatch / Read / Close on this Queue.
} func (r *tunFile) ReadBatch() ([][]byte, error) {
r.pending = r.pending[:0]
r.pendingIdx = 0
for {
n, err := r.readRaw(r.readBuf)
if err != nil {
return nil, err
}
if !r.vnetHdr {
r.pending = append(r.pending, r.readBuf[:n])
return r.pending, nil
}
if n < virtioNetHdrLen {
return nil, fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
}
var hdr virtioNetHdr
hdr.decode(r.readBuf[:virtioNetHdrLen])
if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf); err != nil {
// Drop and read again — a bad packet should not kill the reader.
continue
}
return r.pending, nil
}
}
// Read drains segments produced by the last ReadBatch one at a time; when the
// batch is exhausted it fetches a fresh one. Kept for io.Reader compatibility;
// batch-aware callers should use ReadBatch directly.
func (r *tunFile) Read(buf []byte) (int, error) {
for { for {
if r.pendingIdx < len(r.pending) { if r.pendingIdx < len(r.pending) {
seg := r.pending[r.pendingIdx] seg := r.pending[r.pendingIdx]
@@ -195,22 +223,9 @@ func (r *tunFile) Read(buf []byte) (int, error) {
} }
return copy(buf, seg), nil return copy(buf, seg), nil
} }
r.pending = r.pending[:0] if _, err := r.ReadBatch(); err != nil {
r.pendingIdx = 0
n, err := r.readRaw(r.readBuf)
if err != nil {
return 0, err return 0, err
} }
if n < virtioNetHdrLen {
return 0, fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
}
var hdr virtioNetHdr
hdr.decode(r.readBuf[:virtioNetHdrLen])
if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf); err != nil {
// Drop and read again — a bad packet should not kill the reader.
continue
}
} }
} }
@@ -540,7 +555,7 @@ func (t *tun) SupportsMultiqueue() bool {
return true return true
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (Queue, error) {
t.closeLock.Lock() t.closeLock.Lock()
defer t.closeLock.Unlock() defer t.closeLock.Unlock()

View File

@@ -1,5 +1,5 @@
//go:build !android && !e2e_testing //go:build linux && !android && !e2e_testing
// +build !android,!e2e_testing // +build linux,!android,!e2e_testing
package overlay package overlay

View File

@@ -1,5 +1,5 @@
//go:build !android && !e2e_testing //go:build linux && !android && !e2e_testing
// +build !android,!e2e_testing // +build linux,!android,!e2e_testing
package overlay package overlay

View File

@@ -6,7 +6,6 @@ package overlay
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"regexp" "regexp"
@@ -66,6 +65,21 @@ type tun struct {
l *logrus.Logger l *logrus.Logger
f *os.File f *os.File
fd int fd int
readBuf []byte
batchRet [1][]byte
}
func (t *tun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -394,7 +408,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
} }

View File

@@ -6,7 +6,6 @@ package overlay
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"regexp" "regexp"
@@ -59,6 +58,21 @@ type tun struct {
fd int fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte out []byte
readBuf []byte
batchRet [1][]byte
}
func (t *tun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -314,7 +328,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
} }

View File

@@ -26,6 +26,21 @@ type TestTun struct {
closed atomic.Bool closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
readBuf []byte
batchRet [1][]byte
}
func (t *TestTun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
@@ -136,6 +151,6 @@ func (t *TestTun) SupportsMultiqueue() bool {
return false return false
} }
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *TestTun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented") return nil, fmt.Errorf("TODO: multiqueue not implemented")
} }

View File

@@ -6,7 +6,6 @@ package overlay
import ( import (
"crypto" "crypto"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
@@ -36,6 +35,21 @@ type winTun struct {
l *logrus.Logger l *logrus.Logger
tun *wintun.NativeTun tun *wintun.NativeTun
readBuf []byte
batchRet [1][]byte
}
func (t *winTun) ReadBatch() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
@@ -241,7 +255,7 @@ func (t *winTun) SupportsMultiqueue() bool {
return false return false
} }
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *winTun) NewMultiQueueReader() (Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
} }

View File

@@ -34,6 +34,21 @@ type UserDevice struct {
inboundReader *io.PipeReader inboundReader *io.PipeReader
inboundWriter *io.PipeWriter inboundWriter *io.PipeWriter
readBuf []byte
batchRet [1][]byte
}
func (d *UserDevice) ReadBatch() ([][]byte, error) {
if d.readBuf == nil {
d.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := d.Read(d.readBuf)
if err != nil {
return nil, err
}
d.batchRet[0] = d.readBuf[:n]
return d.batchRet[:], nil
} }
func (d *UserDevice) Activate() error { func (d *UserDevice) Activate() error {
@@ -50,7 +65,7 @@ func (d *UserDevice) SupportsMultiqueue() bool {
return true return true
} }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (Queue, error) {
return d, nil return d, nil
} }