Compare commits

..

17 Commits

Author SHA1 Message Date
Jay Wren
be90e4aa05 handle virtio header in ctrl messages 2025-11-19 17:09:39 -05:00
Jay Wren
bc9711df68 batch more writes 2025-11-19 16:58:59 -05:00
Jay Wren
4e333c76ba write batching 2025-11-19 14:03:36 -05:00
Jay Wren
f29e21b411 don't register metrics in loops 2025-11-19 13:25:25 -05:00
Jay Wren
8b32382cd9 zero copy even with virtioheder 2025-11-19 12:03:38 -05:00
Jay Wren
518a78c9d2 preallocate nonce buffer 2025-11-18 14:19:05 -05:00
Jay Wren
7c3708561d instruments 2025-11-14 14:43:51 -05:00
Jay Wren
a62ffca975 fix 32bit 2025-11-13 15:10:51 -05:00
Jay Wren
226787ea1f prealloc them buffers 2025-11-11 15:20:50 -05:00
Jay Wren
b2bc6a09ca write in batches 2025-11-11 15:06:45 -05:00
Jay Wren
0f9b33aa36 reduce copying 2025-11-11 14:51:53 -05:00
Jay Wren
ef0a022375 more nonblocking 2025-11-11 14:22:40 -05:00
Jay Wren
b68e504865 hrm 2025-11-11 13:15:30 -05:00
Jay Wren
3344a840d1 just using the wg library works 2025-11-11 10:55:39 -05:00
Jay Wren
2bc9863e66 only wg tun, no batching 2025-11-10 16:54:00 -05:00
Wade Simmons
97b3972c11 honor remote_allow_list in hole punch response (#1186)
* honor remote_allow_ilst in hole punch response

When we receive a "hole punch notification" from a Lighthouse, we send
a hole punch packet to every remote of that host, even if we don't
include those remotes in our "remote_allow_list". Change the logic here
to check if the remote IP is in our allow list before sending the hole
punch packet.

* fix for netip

* cleanup
2025-11-10 13:52:40 -05:00
Jack Doan
0f305d5397 don't block startup on failure to configure SSH (#1520) 2025-11-05 10:41:56 -06:00
31 changed files with 1233 additions and 247 deletions

View File

@@ -65,16 +65,8 @@ func main() {
}
if !*configTest {
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
ctrl.Start()
ctrl.ShutdownBlock()
}
os.Exit(0)

View File

@@ -3,9 +3,6 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"github.com/sirupsen/logrus"
@@ -61,22 +58,10 @@ func main() {
os.Exit(1)
}
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest {
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
ctrl.Start()
notifyReady(l)
wait()
l.Info("Goodbye")
ctrl.ShutdownBlock()
}
os.Exit(0)

View File

@@ -2,11 +2,9 @@ package nebula
import (
"context"
"errors"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
@@ -15,16 +13,6 @@ import (
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -38,9 +26,6 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
ctx context.Context
@@ -64,21 +49,10 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call.
// The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Activate the interface
err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
c.f.activate()
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -98,33 +72,15 @@ func (c *Control) Start() (func(), error) {
}
// Start reading packets.
c.state = Started
c.stateLock.Unlock()
return c.f.run()
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
c.f.run()
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
@@ -133,7 +89,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
}
c.state = Stopped
c.l.Info("Goodbye")
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

143
inside.go
View File

@@ -11,6 +11,149 @@ import (
"github.com/slackhq/nebula/routing"
)
// consumeInsidePackets processes multiple packets in a batch for improved performance
// packets: slice of packet buffers to process
// sizes: slice of packet sizes
// count: number of packets to process
// outs: slice of output buffers (one per packet) with virtio headroom
// q: queue index
// localCache: firewall conntrack cache
// batchPackets: pre-allocated slice for accumulating encrypted packets
// batchAddrs: pre-allocated slice for accumulating destination addresses
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
// Reusable per-packet state
fwPacket := &firewall.Packet{}
// Reset batch accumulation slices (reuse capacity)
*batchPackets = (*batchPackets)[:0]
*batchAddrs = (*batchAddrs)[:0]
// Process each packet in the batch
for i := 0; i < count; i++ {
packet := packets[i][:sizes[i]]
out := outs[i]
// Inline the consumeInsidePacket logic for better performance
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
}
continue
}
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
continue
}
}
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
// Immediately forward packets from self to self.
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
}
}
continue
}
// Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
continue
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
}
continue
}
if !ready {
continue
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
continue
}
// Encrypt and prepare packet for batch sending
ci := hostinfo.ConnectionState
if ci.eKey == nil {
continue
}
// Check if this needs relay - if so, send immediately and skip batching
useRelay := !hostinfo.remote.IsValid()
if useRelay {
// Handle relay sends individually (less common path)
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
continue
}
// Encrypt the packet for batch sending
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query lighthouse if needed
if hostinfo.lastRebindCount != f.rebindCount {
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("counter", c).
Error("Failed to encrypt outgoing packet")
continue
}
// Add to batch
*batchPackets = append(*batchPackets, out)
*batchAddrs = append(*batchAddrs, hostinfo.remote)
}
// Send all accumulated packets in one batch
if len(*batchPackets) > 0 {
batchSize := len(*batchPackets)
f.batchMetrics.udpWriteSize.Update(int64(batchSize))
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
}
}
}
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {

View File

@@ -4,10 +4,9 @@ import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -22,6 +21,7 @@ import (
)
const mtu = 9001
const virtioNetHdrLen = overlay.VirtioNetHdrLen
type InterfaceConfig struct {
HostMap *HostMap
@@ -50,6 +50,13 @@ type InterfaceConfig struct {
l *logrus.Logger
}
type batchMetrics struct {
udpReadSize metrics.Histogram
tunReadSize metrics.Histogram
udpWriteSize metrics.Histogram
tunWriteSize metrics.Histogram
}
type Interface struct {
hostMap *HostMap
outside udp.Conn
@@ -86,12 +93,12 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []udp.Conn
readers []io.ReadWriteCloser
wg sync.WaitGroup
readers []overlay.BatchReadWriter
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
batchMetrics *batchMetrics
l *logrus.Logger
}
@@ -178,7 +185,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
readers: make([]overlay.BatchReadWriter, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
@@ -194,6 +201,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
batchMetrics: &batchMetrics{
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)),
udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)),
tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)),
},
l: c.l,
}
@@ -210,7 +223,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() error {
func (f *Interface) activate() {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
@@ -226,43 +239,38 @@ func (f *Interface) activate() error {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
var reader io.ReadWriteCloser = f.inside
var reader overlay.BatchReadWriter = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
return err
f.l.Fatal(err)
}
}
f.readers[i] = reader
}
if err = f.inside.Activate(); err != nil {
if err := f.inside.Activate(); err != nil {
f.inside.Close()
return err
f.l.Fatal(err)
}
return nil
}
func (f *Interface) run() (func(), error) {
func (f *Interface) run() {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
f.wg.Add(1)
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
f.wg.Add(1)
go f.listenIn(f.readers[i], i)
}
return f.wg.Wait, nil
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -272,48 +280,70 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
// Pre-allocate output buffers for batch processing
batchSize := li.BatchSize()
outs := make([][]byte, batchSize)
for idx := range outs {
// Allocate full buffer with virtio header space
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
}
f.l.Debugf("underlay reader %v is done", i)
f.wg.Done()
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12)
li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
})
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
batchSize := reader.BatchSize()
// Allocate buffers for batch reading
bufs := make([][]byte, batchSize)
for idx := range bufs {
bufs[idx] = make([]byte, mtu)
}
sizes := make([]int, batchSize)
// Allocate output buffers for batch processing (one per packet)
// Each has virtio header headroom to avoid copies on write
outs := make([][]byte, batchSize)
for idx := range outs {
outBuf := make([]byte, virtioNetHdrLen+mtu)
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
}
// Pre-allocate batch accumulation buffers for sending
batchPackets := make([][]byte, 0, batchSize)
batchAddrs := make([]netip.AddrPort, 0, batchSize)
// Pre-allocate nonce buffer (reused for all encryptions)
nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
n, err := reader.BatchRead(bufs, sizes)
if err != nil {
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
break
f.l.WithError(err).Error("Error while batch reading outbound packets")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
f.batchMetrics.tunReadSize.Update(int64(n))
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
// Process all packets in the batch at once
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
}
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -465,7 +495,6 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error {
f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers {
err := u.Close()
if err != nil {
@@ -473,13 +502,6 @@ func (f *Interface) Close() error {
}
}
// Release the tun readers
for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
// Release the tun device
return f.inside.Close()
}

View File

@@ -1337,12 +1337,19 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}
}
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts {
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
}
for _, a := range n.Details.V6AddrPorts {
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
}
// This sends a nebula test packet to the host trying to contact us. In the case

23
main.go
View File

@@ -75,7 +75,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
sshStart = nil
}
}
@@ -164,7 +165,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
@@ -284,14 +285,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
return &Control{
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: statsStart,
dnsStart: dnsStart,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
}, nil
}

View File

@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
//f.l.Error("in packet ", h)
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {
@@ -95,8 +95,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -138,7 +137,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
// Fallthrough to the bottom to record incoming traffic
@@ -160,7 +159,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
}
// Fallthrough to the bottom to record incoming traffic
@@ -203,7 +202,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -474,9 +473,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
err = newPacket(out, true, fwPacket)
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
Warnf("Error while validating inbound packet")
return false
}
@@ -491,7 +492,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
@@ -548,3 +549,108 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
// We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
// and writes all successfully decrypted packets to TUN in a single operation
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
// Pre-allocate slice for accumulating successful decryptions
tunPackets := make([][]byte, 0, count)
for i := 0; i < count; i++ {
payload := payloads[i]
addr := addrs[i]
out := outs[i]
// Parse header
err := h.Parse(payload)
if err != nil {
if len(payload) > 1 {
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
continue
}
if addr.IsValid() {
if f.myVpnNetworksTable.Contains(addr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
}
continue
}
}
var hostinfo *HostInfo
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, addr, h) {
continue
}
switch h.Subtype {
case header.MessageNone:
// Decrypt packet
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
continue
}
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
continue
}
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
continue
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
}
continue
}
f.connectionManager.In(hostinfo)
// Add to batch for TUN write
tunPackets = append(tunPackets, out)
case header.MessageRelay:
// Skip relay packets in batch mode for now (less common path)
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
default:
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
}
default:
// Handle non-Message types using single-packet path
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
}
}
if len(tunPackets) > 0 {
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
}
f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets)))
}
}

View File

@@ -7,11 +7,25 @@ import (
"github.com/slackhq/nebula/routing"
)
type Device interface {
// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations
type BatchReadWriter interface {
io.ReadWriteCloser
// BatchRead reads multiple packets at once
BatchRead(bufs [][]byte, sizes []int) (int, error)
// WriteBatch writes multiple packets at once
WriteBatch(bufs [][]byte, offset int) (int, error)
// BatchSize returns the optimal batch size for this device
BatchSize() int
}
type Device interface {
BatchReadWriter
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (io.ReadWriteCloser, error)
NewMultiQueueReader() (BatchReadWriter, error)
}

View File

@@ -11,6 +11,7 @@ import (
)
const DefaultMTU = 1300
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)

View File

@@ -95,6 +95,29 @@ func (t *tun) Name() string {
return "android"
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -549,6 +549,32 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
// BatchRead reads a single packet (batch size 1 for non-Linux platforms)
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for non-Linux platforms)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for non-Linux platforms (no batching)
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -105,10 +105,36 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) {
return t, nil
}
// BatchRead reads a single packet (batch size 1 for disabled tun)
func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for disabled tun)
func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for disabled tun (no batching)
func (t *disabledTun) BatchSize() int {
return 1
}
func (t *disabledTun) Close() error {
if t.read != nil {
close(t.read)

View File

@@ -450,10 +450,36 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
// BatchRead reads a single packet (batch size 1 for FreeBSD)
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for FreeBSD)
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for FreeBSD (no batching)
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {

View File

@@ -151,6 +151,29 @@ func (t *tun) Name() string {
return "iOS"
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}

View File

@@ -9,7 +9,6 @@ import (
"net"
"net/netip"
"os"
"strings"
"sync/atomic"
"time"
"unsafe"
@@ -21,10 +20,12 @@ import (
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
wgtun "golang.zx2c4.com/wireguard/tun"
)
type tun struct {
io.ReadWriteCloser
wgDevice wgtun.Device
fd int
Device string
vpnNetworks []netip.Prefix
@@ -65,59 +66,154 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
// This allows multiqueue readers to use the same wireguard Device batching as the main device
type wgDeviceWrapper struct {
dev wgtun.Device
buf []byte // Reusable buffer for single packet reads
}
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
// Use wireguard Device's batch API for single packet
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := w.dev.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
// Pass all buffers to WireGuard's batch write
return w.dev.Write(bufs, offset)
}
func (w *wgDeviceWrapper) Close() error {
return w.dev.Close()
}
// BatchRead implements batching for multiqueue readers
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
// The zero here is offset.
return w.dev.Read(bufs, sizes, 0)
}
// BatchSize returns the optimal batch size
func (w *wgDeviceWrapper) BatchSize() int {
return w.dev.BatchSize()
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
if err != nil {
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
}
file := wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.Device = "tun0"
t.wgDevice = wgDev
t.Device = name
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
if os.IsNotExist(err) {
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
if err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err
// Check if /dev/net/tun exists, create if needed (for docker containers)
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
if err := os.MkdirAll("/dev/net", 0755); err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
}
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
name := strings.Trim(string(req.Name[:]), "\x00")
devName := c.GetString("tun.dev", "")
mtu := c.GetInt("tun.mtu", DefaultMTU)
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
// Create TUN device manually to support multiqueue
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], devName)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Enable TCP and UDP offload (TSO/GRO) for performance
// This allows the kernel to handle segmentation/coalescing
const (
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
offloads := tunTCPOffloads | tunUDPOffloads
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
// Log warning but don't fail - offload is optional
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
// Create wireguard device from file descriptor
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
}
name, err := wgDev.Name()
if err != nil {
_ = wgDev.Close()
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
}
// file is now owned by wgDev, get a new reference
file = wgDev.File()
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
_ = wgDev.Close()
return nil, err
}
t.wgDevice = wgDev
t.Device = name
return t, nil
@@ -216,22 +312,44 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
// MUST match the flags used in newTun - includes IFF_VNET_HDR
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err
}
// Set nonblocking mode - CRITICAL for proper netpoller integration
if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
// Get MTU from main device
mtu := t.MaxMTU
if mtu == 0 {
mtu = DefaultMTU
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil
// Create wireguard Device from the file descriptor (just like the main device)
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
}
// Return a wrapper that uses the wireguard Device for all I/O
return &wgDeviceWrapper{dev: wgDev}, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -239,7 +357,68 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r
}
func (t *tun) Read(b []byte) (int, error) {
if t.wgDevice != nil {
// Use wireguard device which handles virtio headers internally
bufs := [][]byte{b}
sizes := make([]int, 1)
n, err := t.wgDevice.Read(bufs, sizes, 0)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
return sizes[0], nil
}
// Fallback: direct read from file (shouldn't happen in normal operation)
return t.ReadWriteCloser.Read(b)
}
// BatchRead reads multiple packets at once for improved performance
// bufs: slice of buffers to read into
// sizes: slice that will be filled with packet sizes
// Returns number of packets read
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Read(bufs, sizes, 0)
}
// Fallback: single packet read
n, err := t.ReadWriteCloser.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// BatchSize returns the optimal number of packets to read/write in a batch
func (t *tun) BatchSize() int {
if t.wgDevice != nil {
return t.wgDevice.BatchSize()
}
return 1
}
func (t *tun) Write(b []byte) (int, error) {
if t.wgDevice != nil {
// Buffer b should have virtio header space (10 bytes) at the beginning
// The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
bufs := [][]byte{b}
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
if err != nil {
return 0, err
}
if n == 0 {
return 0, io.ErrShortWrite
}
return len(b), nil
}
// Fallback: direct write (shouldn't happen in normal operation)
var nn int
maximum := len(b)
@@ -262,6 +441,22 @@ func (t *tun) Write(b []byte) (int, error) {
}
}
// WriteBatch writes multiple packets to the TUN device in a single syscall
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
if t.wgDevice != nil {
return t.wgDevice.Write(bufs, offset)
}
// Fallback: write individually (shouldn't happen in normal operation)
for i, buf := range bufs {
_, err := t.Write(buf)
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
@@ -674,6 +869,10 @@ func (t *tun) Close() error {
close(t.routeChan)
}
if t.wgDevice != nil {
_ = t.wgDevice.Close()
}
if t.ReadWriteCloser != nil {
_ = t.ReadWriteCloser.Close()
}

View File

@@ -390,10 +390,33 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()

View File

@@ -310,10 +310,33 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *tun) BatchSize() int {
return 1
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()

View File

@@ -132,6 +132,29 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}
func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
func (t *TestTun) BatchSize() int {
return 1
}

View File

@@ -6,7 +6,6 @@ package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
@@ -234,10 +233,36 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
// BatchRead reads a single packet (batch size 1 for Windows)
func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := t.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for Windows)
func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := t.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for Windows (no batching)
func (t *winTun) BatchSize() int {
return 1
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.

View File

@@ -46,10 +46,36 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) {
return d, nil
}
// BatchRead reads a single packet (batch size 1 for UserDevice)
func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
n, err := d.Read(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
// WriteBatch writes packets individually (no batching for UserDevice)
func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
for i, buf := range bufs {
_, err := d.Write(buf[offset:])
if err != nil {
return i, err
}
}
return len(bufs), nil
}
// BatchSize returns 1 for UserDevice (no batching)
func (d *UserDevice) BatchSize() int {
return 1
}
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
return d.inboundReader, d.outboundWriter
}

View File

@@ -44,10 +44,7 @@ type Service struct {
}
func New(control *nebula.Control) (*Service, error) {
wait, err := control.Start()
if err != nil {
return nil, err
}
control.Start()
ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx)
@@ -144,12 +141,6 @@ func New(control *nebula.Control) (*Service, error) {
}
})
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil
}

View File

@@ -6,6 +6,7 @@ import (
"log"
"net"
"net/http"
_ "net/http/pprof"
"runtime"
"strconv"
"time"

View File

@@ -13,12 +13,21 @@ type EncReader func(
payload []byte,
)
type EncBatchReader func(
addrs []netip.AddrPort,
payloads [][]byte,
count int,
)
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) error
ListenOut(r EncReader)
ListenOutBatch(r EncBatchReader)
WriteTo(b []byte, addr netip.AddrPort) error
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
ReloadConfig(c *config.C)
BatchSize() int
Close() error
}
@@ -30,15 +39,24 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) error {
return nil
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) ListenOutBatch(_ EncBatchReader) {
return
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
return 0, nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
return
}
func (NoopConn) BatchSize() int {
return 1
}
func (NoopConn) Close() error {
return nil
}

View File

@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
}
}
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
@@ -165,7 +176,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) error {
func (u *StdConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
@@ -173,7 +184,8 @@ func (u *StdConn) ListenOut(r EncReader) error {
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
@@ -183,6 +195,34 @@ func (u *StdConn) ListenOut(r EncReader) error {
}
}
// ListenOutBatch - fallback to single-packet reads for Darwin
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
func (u *StdConn) BatchSize() int {
return 1
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {

View File

@@ -71,16 +71,56 @@ type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) error {
func (u *GenericConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
// ListenOutBatch - fallback to single-packet reads for generic platforms
func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
buffer := make([]byte, MTU)
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
// Just read one packet at a time and call batch callback with count=1
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
payloads[0] = buffer[:n]
r(addrs, payloads, 1)
}
}
// WriteMulti sends multiple packets - fallback implementation
func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *GenericConn) BatchSize() int {
return 1
}
func (u *GenericConn) Rebind() error {
return nil
}

View File

@@ -9,7 +9,6 @@ import (
"net"
"net/netip"
"syscall"
"time"
"unsafe"
"github.com/rcrowley/go-metrics"
@@ -18,13 +17,24 @@ import (
"golang.org/x/sys/unix"
)
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
type StdConn struct {
sysFd int
isV4 bool
l *logrus.Logger
batch int
// Pre-allocated buffers for batch writes (sized for IPv6, works for both)
writeMsgs []rawMessage
writeIovecs []iovec
writeNames [][]byte
}
func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4()
if ip4 != nil {
return ip4, true
}
return ip, false
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -50,11 +60,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
}
}
// Set a read timeout
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
}
var sa unix.Sockaddr
if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port}
@@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}
// Pre-allocate write message structures for batching (sized for IPv6, works for both)
c.writeMsgs = make([]rawMessage, batch)
c.writeIovecs = make([]iovec, batch)
c.writeNames = make([][]byte, batch)
for i := range c.writeMsgs {
// Allocate for IPv6 size (larger than IPv4, works for both)
c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
// Point to the iovec in the slice
c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i]
c.writeMsgs[i].Hdr.Iovlen = 1
c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0]
// Namelen will be set appropriately in writeMulti4/writeMulti6
}
return c, err
}
func (u *StdConn) Rebind() error {
@@ -118,7 +142,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
}
}
func (u *StdConn) ListenOut(r EncReader) error {
func (u *StdConn) ListenOut(r EncReader) {
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
@@ -127,12 +151,17 @@ func (u *StdConn) ListenOut(r EncReader) error {
read = u.ReadSingle
}
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
for {
n, err := read(msgs)
if err != nil {
return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpBatchHist.Update(int64(n))
for i := 0; i < n; i++ {
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 {
@@ -145,6 +174,46 @@ func (u *StdConn) ListenOut(r EncReader) error {
}
}
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti
if u.batch == 1 {
read = u.ReadSingle
}
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
// Pre-allocate slices for batch callback
addrs := make([]netip.AddrPort, u.batch)
payloads := make([][]byte, u.batch)
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpBatchHist.Update(int64(n))
// Prepare batch data
for i := 0; i < n; i++ {
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payloads[i] = buffers[i][:msgs[i].Len]
}
// Call batch callback with all packets
r(addrs, payloads, n)
}
}
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
@@ -158,9 +227,6 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
)
if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmsg", Err: err}
}
@@ -182,9 +248,6 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
)
if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmmsg", Err: err}
}
@@ -199,6 +262,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip)
}
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
if len(packets) != len(addrs) {
return 0, fmt.Errorf("packets and addrs length mismatch")
}
if len(packets) == 0 {
return 0, nil
}
if u.isV4 {
return u.writeMulti4(packets, addrs)
}
return u.writeMulti6(packets, addrs)
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
@@ -253,6 +329,123 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
}
}
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
if !addrs[pktIdx].Addr().Is4() {
return sent + i, ErrInvalidIPv6RemoteForSocket
}
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET
rsa.Addr = addrs[pktIdx].Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv4
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
sent := 0
for sent < len(packets) {
// Determine batch size based on remaining packets and buffer capacity
batchSize := len(packets) - sent
if batchSize > len(u.writeMsgs) {
batchSize = len(u.writeMsgs)
}
// Use pre-allocated buffers
msgs := u.writeMsgs[:batchSize]
iovecs := u.writeIovecs[:batchSize]
names := u.writeNames[:batchSize]
// Setup message structures for this batch
for i := 0; i < batchSize; i++ {
pktIdx := sent + i
// Setup the packet buffer
iovecs[i].Base = &packets[pktIdx][0]
iovecs[i].Len = uint(len(packets[pktIdx]))
// Setup the destination address
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
rsa.Family = unix.AF_INET6
rsa.Addr = addrs[pktIdx].Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
// Set the appropriate address length for IPv6
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
}
// Send this batch
nsent, _, err := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(batchSize),
0,
0,
0,
)
if err != 0 {
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
}
sent += int(nsent)
if int(nsent) < batchSize {
// Couldn't send all packets in batch, return what we sent
return sent, nil
}
}
return sent, nil
}
func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
@@ -310,6 +503,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
return nil
}
func (u *StdConn) BatchSize() int {
return u.batch
}
func (u *StdConn) Close() error {
return syscall.Close(u.sysFd)
}

View File

@@ -12,7 +12,7 @@ import (
type iovec struct {
Base *byte
Len uint32
Len uint
}
type msghdr struct {
@@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]

View File

@@ -12,7 +12,7 @@ import (
type iovec struct {
Base *byte
Len uint64
Len uint
}
type msghdr struct {
@@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]

View File

@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil
}
func (u *RIOConn) ListenOut(r EncReader) error {
func (u *RIOConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
@@ -142,7 +142,8 @@ func (u *RIOConn) ListenOut(r EncReader) error {
n, rua, err := u.receive(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
continue

View File

@@ -6,7 +6,6 @@ package udp
import (
"io"
"net/netip"
"os"
"sync/atomic"
"github.com/sirupsen/logrus"
@@ -107,16 +106,41 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil
}
func (u *TesterConn) ListenOut(r EncReader) error {
func (u *TesterConn) ListenOut(r EncReader) {
for {
p, ok := <-u.RxPackets
if !ok {
return os.ErrClosed
return
}
r(p.From, p.Data)
}
}
func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
addrs := make([]netip.AddrPort, 1)
payloads := make([][]byte, 1)
for {
p, ok := <-u.RxPackets
if !ok {
return
}
addrs[0] = p.From
payloads[0] = p.Data
r(addrs, payloads, 1)
}
}
func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
for i := range packets {
err := u.WriteTo(packets[i], addrs[i])
if err != nil {
return i, err
}
}
return len(packets), nil
}
func (u *TesterConn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []Conn) func() {
@@ -128,6 +152,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
func (u *TesterConn) BatchSize() int {
return 1
}
func (u *TesterConn) Rebind() error {
return nil
}