Merge remote-tracking branch 'origin/master' into multiport

This commit is contained in:
Wade Simmons
2026-05-06 14:26:49 -04:00
138 changed files with 10562 additions and 4541 deletions

View File

@@ -5,15 +5,15 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -30,7 +30,7 @@ type InterfaceConfig struct {
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
DnsServer *dnsServer
HandshakeManager *HandshakeManager
lightHouse *LightHouse
connectionManager *connectionManager
@@ -47,7 +47,7 @@ type InterfaceConfig struct {
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
l *logrus.Logger
l *slog.Logger
}
type Interface struct {
@@ -58,7 +58,7 @@ type Interface struct {
firewall *Firewall
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
dnsServer *dnsServer
createTime time.Time
lightHouse *LightHouse
myBroadcastAddrsTable *bart.Lite
@@ -86,17 +86,25 @@ type Interface struct {
conntrackCacheTimeout time.Duration
ctx context.Context
writers []udp.Conn
readers []io.ReadWriteCloser
udpRaw *udp.RawConn
wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown.
// nil means "no fatal error" (yet)
fatalErr atomic.Pointer[error]
// triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr
triggerShutdown func()
udpRaw *udp.RawConn
multiPort MultiPortConfig
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger
l *slog.Logger
}
type MultiPortConfig struct {
@@ -176,12 +184,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
cs := c.pki.getCertState()
ifce := &Interface{
ctx: ctx,
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
firewall: c.Firewall,
serveDns: c.ServeDns,
dnsServer: c.DnsServer,
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
@@ -222,18 +231,21 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
func (f *Interface) activate() error {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
if err != nil {
f.l.WithError(err).Error("Failed to get udp listen address")
f.l.Error("Failed to get udp listen address", "error", err)
}
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
f.l.Info("Nebula interface is active",
"interface", f.inside.Name(),
"networks", f.myVpnNetworks,
"build", f.version,
"udpAddr", addr,
"boringcrypto", boringEnabled(),
)
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
@@ -252,33 +264,58 @@ func (f *Interface) activate() {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
f.l.Fatal(err)
return err
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil {
f.wg.Done()
f.inside.Close()
f.l.Fatal(err)
return err
}
return nil
}
// run starts the packet readers: f.routines readers on the udp side (listenOut)
// and f.routines readers on the tun side (listenIn), each given its index i.
//
// In the two-value form it returns a wait function and an error that is always
// nil here. The wait function blocks on f.wg and then returns the error stored
// in f.fatalErr, or nil if none was recorded.
//
// NOTE(review): this listing is a diff rendering with the +/- markers stripped,
// so superseded and replacement lines sit side by side (two signatures, a bare
// `go ...` launch next to each f.wg.Go launch). Confirm against the real file
// which line of each pair is live.
func (f *Interface) run() {
func (f *Interface) run() (func() error, error) {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
go f.listenOut(i)
// NOTE(review): the closure captures i; this is only correct with
// per-iteration loop variables (Go 1.22+ language version) — confirm go.mod.
f.wg.Go(func() {
f.listenOut(i)
})
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
f.wg.Go(func() {
f.listenIn(f.readers[i], i)
})
}
return func() error {
// Wait covers the reader goroutines started above; activate() also adds one
// count to f.wg that Close() releases, so this does not return before Close().
f.wg.Wait()
// fatalErr stays nil unless onFatal recorded an error.
if e := f.fatalErr.Load(); e != nil {
return *e
}
return nil
}, nil
}
// onFatal records err as the interface's fatal error if none has been recorded
// yet. Only the call that wins the nil -> &err swap on fatalErr goes on to
// invoke triggerShutdown (when one is set); every later call is a no-op.
func (f *Interface) onFatal(err error) {
	// CompareAndSwap succeeds only while fatalErr is still nil, so at most one
	// caller ever gets past this condition.
	if f.fatalErr.CompareAndSwap(nil, &err) && f.triggerShutdown != nil {
		f.triggerShutdown()
	}
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -286,42 +323,47 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
})
if err != nil && !f.closed.Load() {
f.l.Error("Error while reading inbound packet, closing", "error", err)
f.onFatal(err)
}
f.l.Debug("underlay reader is done", "reader", i)
}
// listenIn reads packets from reader in a loop and passes each one to
// consumeInsidePacket, with i identifying this reader. The loop stops when
// reader.Read returns an error.
//
// NOTE(review): this listing is a diff rendering with the +/- markers stripped,
// so superseded and replacement lines sit side by side (two conntrackCache
// declarations, two error-handling paths, two consumeInsidePacket calls).
// Confirm against the real file which line of each pair is live.
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
// Wire this goroutine to its current OS thread.
runtime.LockOSThread()
// Buffers allocated once here and reused on every loop iteration.
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
// A read error while f.closed is false is logged and recorded via onFatal;
// once Close() has set f.closed the error is not reported.
if !f.closed.Load() {
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
f.onFatal(err)
}
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
// Leave the read loop on any error.
break
}
// Only the first n bytes of packet hold the packet that was just read.
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
}
f.l.Debug("overlay reader is done", "reader", i)
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -341,7 +383,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
if initial || c.HasChanged("pki.disconnect_invalid") {
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
if !initial {
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
}
}
}
@@ -355,7 +397,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
f.l.Error("Error while creating firewall during reload", "error", err)
return
}
@@ -368,10 +410,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
} else {
fw.Conntrack = conntrack
}
@@ -379,10 +422,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
f.firewall = fw
oldFw.Destroy()
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed")
f.l.Info("New firewall has been installed",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
}
func (f *Interface) reloadSendRecvError(c *config.C) {
@@ -404,8 +448,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
}
}
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
Info("Loaded send_recv_error config")
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
}
}
@@ -428,8 +471,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
}
}
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
Info("Loaded accept_recv_error config")
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
}
}
@@ -505,15 +547,23 @@ func (f *Interface) GetCertState() *CertState {
}
// Close marks the interface as closed, closes every entry in f.writers, then
// closes the tun device (f.inside). In the errors.Join form it collects every
// close error and returns them joined (nil when nothing failed).
//
// NOTE(review): this listing is a diff rendering with the +/- markers stripped,
// so superseded and replacement lines sit side by side (two range loops, two
// log calls, an early `return f.inside.Close()` next to the errors.Join path).
// Confirm against the real file which line of each pair is live.
func (f *Interface) Close() error {
var errs []error
// Set before closing anything: listenIn/listenOut check f.closed to decide
// whether a read error should be reported as fatal.
f.closed.Store(true)
for _, u := range f.writers {
// Release the udp readers
for i, u := range f.writers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing udp socket")
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
errs = append(errs, err)
}
}
// Release the tun device
return f.inside.Close()
// Release the tun device (closing the tun also closes all readers)
closeErr := f.inside.Close()
if closeErr != nil {
errs = append(errs, closeErr)
}
// Releases the f.wg count that activate() adds "to wait on Close() to return".
// NOTE(review): a second Close() call would run Done() again with no matching
// Add — confirm callers invoke Close exactly once.
f.wg.Done()
return errors.Join(errs...)
}