Switch to slog, remove logrus (#1672)

This commit is contained in:
Nate Brown
2026-04-27 09:41:47 -05:00
committed by GitHub
parent 5f890dbc34
commit d0f02ba873
77 changed files with 2299 additions and 1338 deletions

View File

@@ -2,7 +2,21 @@ version: "2"
linters:
default: none
enable:
- sloglint
- testifylint
settings:
sloglint:
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
# custom levels instead of LogAttrs when you can.
#
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
# the few legitimate sites (where attrs is built up as a []slog.Attr)
# carry a //nolint:sloglint with rationale.
kv-only: true
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
# discard-handler is on by default (since Go 1.24): suggests
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
exclusions:
generated: lax
presets:

38
bits.go
View File

@@ -1,8 +1,10 @@
package nebula
import (
"context"
"log/slog"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
type Bits struct {
@@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits {
return b
}
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current {
return true
@@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
}
// Not within the window
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("rejected a packet (top)",
"current", b.current,
"incoming", i,
)
}
return false
}
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
@@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "duplicate",
)
}
b.dupeCounter.Inc(1)
return false
@@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// In all other cases, fail and don't change current.
b.outOfWindowCounter.Inc(1)
if l.Level >= logrus.DebugLevel {
l.WithField("accepted", false).
WithField("currentCounter", b.current).
WithField("incomingCounter", i).
WithField("reason", "nonsense").
Debug("Receive window")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "nonsense",
)
}
return false
}

View File

@@ -3,8 +3,15 @@
package main
import "github.com/sirupsen/logrus"
import (
"log/slog"
"os"
func HookLogger(l *logrus.Logger) {
// Do nothing, let the logs flow to stdout/stderr
"github.com/slackhq/nebula/logging"
)
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
// platforms have no special sink to integrate with.
func newPlatformLogger() *slog.Logger {
	// Plain stdout sink; logging.NewLogger supplies the default handler setup.
	sink := os.Stdout
	return logging.NewLogger(sink)
}

View File

@@ -1,54 +1,86 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"context"
"log/slog"
"strings"
"sync"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/logging"
)
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
// logrus output will be discarded
func HookLogger(l *logrus.Logger) {
l.AddHook(newLogHook(logger))
l.SetOutput(ioutil.Discard)
// newPlatformLogger returns a *slog.Logger that routes every log record
// through the Windows service logger so records end up in the Windows
// Event Log. All the heavy lifting (level management, format swap,
// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
// this file only contributes:
//
// - an io.Writer that forwards each formatted line to the service
// logger at the current record's Event Log severity, and
// - a thin severityTag that embeds *logging.Handler and overrides
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
// column and severity-based filters keep working the way they did
// before the slog migration.
//
// Format (text vs json) is carried by the embedded *logging.Handler, so
// logging.format: json in config still produces JSON lines in Event
// Viewer, same as the pre-slog logrus setup.
func newPlatformLogger() *slog.Logger {
	// Build the shared writer first, then wrap the stock handler in
	// severityTag so each record's level is stashed before formatting.
	sink := &eventLogWriter{}
	tagged := &severityTag{Handler: logging.NewHandler(sink), w: sink}
	return slog.New(tagged)
}
type logHook struct {
sl service.Logger
// eventLogWriter forwards slog-formatted lines to the Windows service
// logger at the severity most recently stashed by severityTag.Handle.
// The mutex serializes the stash + inner.Handle + Write cycle per record
// across all concurrent goroutines; slog's builtin text/json handlers
// each hold their own mutex around Write, but that only protects the
// Write call itself, not our stash-then-handle sequence.
type eventLogWriter struct {
	mu    sync.Mutex // serializes the stash (level) + Handle + Write cycle per record
	level slog.Level // severity stashed by severityTag.Handle for the record currently being written
}
func newLogHook(sl service.Logger) *logHook {
return &logHook{sl: sl}
}
func (h *logHook) Fire(entry *logrus.Entry) error {
line, err := entry.String()
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
return err
}
switch entry.Level {
case logrus.PanicLevel:
return h.sl.Error(line)
case logrus.FatalLevel:
return h.sl.Error(line)
case logrus.ErrorLevel:
return h.sl.Error(line)
case logrus.WarnLevel:
return h.sl.Warning(line)
case logrus.InfoLevel:
return h.sl.Info(line)
case logrus.DebugLevel:
return h.sl.Info(line)
func (w *eventLogWriter) Write(p []byte) (int, error) {
line := strings.TrimRight(string(p), "\n")
switch {
case w.level >= slog.LevelError:
return len(p), logger.Error(line)
case w.level >= slog.LevelWarn:
return len(p), logger.Warning(line)
default:
return nil
return len(p), logger.Info(line)
}
}
func (h *logHook) Levels() []logrus.Level {
return logrus.AllLevels
// severityTag embeds *logging.Handler to pick up everything it does for
// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
// so each record's slog.Level is stashed on the writer before formatting
// and so derived handlers stay wrapped as severityTag rather than
// downgrading to bare *logging.Handler.
type severityTag struct {
	*logging.Handler                 // embedded: supplies Enabled, level/format management, timestamp toggle
	w                *eventLogWriter // shared writer that reads the per-record level stashed by Handle
}
// Handle stashes the record's level on the shared writer, then formats and
// emits the record while still holding the lock so a concurrent record
// cannot interleave its stash with this record's write.
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
	w := s.w
	w.mu.Lock()
	defer w.mu.Unlock()
	w.level = r.Level
	return s.Handler.Handle(ctx, r)
}
// WithAttrs derives a handler carrying attrs, re-wrapped as severityTag so
// per-record severity stashing survives derivation.
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
	// Empty attrs: nothing to derive, hand back the same handler.
	if len(attrs) == 0 {
		return s
	}
	derived := s.Handler.WithAttrs(attrs).(*logging.Handler)
	return &severityTag{Handler: derived, w: s.w}
}
// WithGroup derives a handler that opens a group, re-wrapped as severityTag
// so derived handlers keep stashing each record's level on the writer.
func (s *severityTag) WithGroup(name string) slog.Handler {
	// slog documents the empty group name as a no-op.
	if name == "" {
		return s
	}
	derived := s.Handler.WithGroup(name).(*logging.Handler)
	return &severityTag{Handler: derived, w: s.w}
}

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util"
)
@@ -50,12 +50,11 @@ func main() {
os.Exit(0)
}
l := logrus.New()
l.Out = os.Stdout
l := logging.NewLogger(os.Stdout)
if *serviceFlag != "" {
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
l.WithError(err).Error("Service command failed")
l.Error("Service command failed", "error", err)
os.Exit(1)
}
return
@@ -74,6 +73,16 @@ func main() {
os.Exit(1)
}
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -90,7 +99,7 @@ func main() {
go ctrl.ShutdownBlock()
if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error")
l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2)
}

View File

@@ -7,9 +7,9 @@ import (
"path/filepath"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
)
var logger service.Logger
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
// Start should not block.
logger.Info("Nebula service starting.")
l := logrus.New()
HookLogger(l)
l := newPlatformLogger()
c := config.NewC(l)
err := c.Load(*p.configPath)
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
return fmt.Errorf("failed to load config: %s", err)
}
if err := logging.ApplyConfig(l, c); err != nil {
return fmt.Errorf("failed to apply logging config: %s", err)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
if err != nil {
return err
@@ -85,7 +93,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
// Here are what the different loggers are doing:
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
// - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
s, err := service.New(prg, svcConfig)
if err != nil {
return err

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util"
)
@@ -55,8 +55,7 @@ func main() {
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
l := logging.NewLogger(os.Stdout)
c := config.NewC(l)
err := c.Load(*configPath)
@@ -65,6 +64,16 @@ func main() {
os.Exit(1)
}
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -82,7 +91,7 @@ func main() {
notifyReady(l)
if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error")
l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2)
}

View File

@@ -1,11 +1,10 @@
package main
import (
"log/slog"
"net"
"os"
"time"
"github.com/sirupsen/logrus"
)
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
@@ -13,30 +12,30 @@ import (
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
const SdNotifyReady = "READY=1"
func notifyReady(l *logrus.Logger) {
func notifyReady(l *slog.Logger) {
sockName := os.Getenv("NOTIFY_SOCKET")
if sockName == "" {
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
return
}
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
if err != nil {
l.WithError(err).Error("failed to connect to systemd notification socket")
l.Error("failed to connect to systemd notification socket", "error", err)
return
}
defer conn.Close()
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
if err != nil {
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
return
}
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
l.WithError(err).Error("failed to signal the systemd notification socket")
l.Error("failed to signal the systemd notification socket", "error", err)
return
}
l.Debugln("notified systemd the service is ready")
l.Debug("notified systemd the service is ready")
}

View File

@@ -3,8 +3,8 @@
package main
import "github.com/sirupsen/logrus"
import "log/slog"
func notifyReady(_ *logrus.Logger) {
func notifyReady(_ *slog.Logger) {
// No init service to notify
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"os"
"os/signal"
@@ -16,7 +17,6 @@ import (
"time"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3"
)
@@ -26,11 +26,11 @@ type C struct {
Settings map[string]any
oldSettings map[string]any
callbacks []func(*C)
l *logrus.Logger
l *slog.Logger
reloadLock sync.Mutex
}
func NewC(l *logrus.Logger) *C {
func NewC(l *slog.Logger) *C {
return &C{
Settings: make(map[string]any),
l: l,
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
newVals, err := yaml.Marshal(nv)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
c.l.Error("Error while marshaling new config",
"config_path", k,
"error", err,
)
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
c.l.Error("Error while marshaling old config",
"config_path", k,
"error", err,
)
}
return string(newVals) != string(oldVals)
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
err := c.Load(c.path)
if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
c.l.Error("Error occurred while reloading config",
"config_path", c.path,
"error", err,
)
return
}

View File

@@ -5,13 +5,13 @@ import (
"context"
"encoding/binary"
"fmt"
"log/slog"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
@@ -47,10 +47,10 @@ type connectionManager struct {
metricsTxPunchy metrics.Counter
l *logrus.Logger
l *slog.Logger
}
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{
hostMap: hm,
l: l,
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
cm.l.Info("Inactivity timeout has changed",
"oldDuration", old,
"newDuration", cm.getInactivityTimeout(),
)
}
}
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
cm.l.Info("Drop inactive setting has changed",
"oldBool", old,
"newBool", cm.dropInactive.Load(),
)
}
}
}
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
var err error
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
continue
}
switch r.Type {
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
msg, err := req.Marshal()
if err != nil {
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
} else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
cm.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnAddrs": newhostinfo.vpnAddrs}).
Info("send CreateRelayRequest")
cm.l.Info("send CreateRelayRequest",
"relayFrom", req.RelayFromAddr,
"relayTo", req.RelayToAddr,
"initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex", req.ResponderRelayIndex,
"vpnAddrs", newhostinfo.vpnAddrs,
)
}
}
}
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
hostinfo := cm.hostMap.Indexes[localIndex]
if hostinfo == nil {
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
return doNothing, nil, nil
}
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
"tunnelCheck", m{"state": "alive", "method": "passive"},
)
}
hostinfo.pendingDeletion.Store(false)
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if hostinfo.pendingDeletion.Load() {
// We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
hostinfo.logger(cm.l).Info("Tunnel status",
"tunnelCheck", m{"state": "dead", "method": "active"},
)
return deleteTunnel, hostinfo, nil
}
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
if isInactive {
// Tunnel is inactive, tear it down
hostinfo.logger(cm.l).
WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
"inactiveDuration", inactiveFor,
"primary", mainHostInfo,
)
return closeTunnel, hostinfo, primary
}
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
cm.sendPunch(hostinfo)
}
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Tunnel status",
"tunnelCheck", m{"state": "testing", "method": "active"},
)
}
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
decision = sendTestPacket
} else {
if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debug("Hostinfo sadness")
}
}
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
return false //cert is still valid! yay!
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
// Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is blocked, tearing down the tunnel")
hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
"error", err,
"fingerprint", remoteCert.Fingerprint,
)
return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
"error", err,
"fingerprint", remoteCert.Fingerprint,
)
return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"version", curCrtVersion,
"reason", "local certificate removed",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"version", curCrtVersion,
"peerVersion", peerCrt.Certificate.Version(),
"reason", "local certificate version lower than peer, attempting to correct",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"reason", "local certificate is not current",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.l.Info("Re-handshaking with remote",
"vpnAddrs", hostinfo.vpnAddrs,
"reason", "current cert version < pki.initiatingVersion",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return

View File

@@ -10,6 +10,7 @@ import (
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
nb := make([]byte, 12, 12)
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
p := []byte("")
nb := make([]byte, 12, 12)
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
@@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true)
// Create manager
conf := config.NewC(l)
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce
ifce.connectionManager = nc

View File

@@ -8,7 +8,6 @@ import (
"sync/atomic"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil"
)
@@ -27,7 +26,7 @@ type ConnectionState struct {
writeLock sync.Mutex
}
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:

View File

@@ -3,13 +3,13 @@ package nebula
import (
"context"
"errors"
"log/slog"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
@@ -46,7 +46,7 @@ type Control struct {
state RunState
f *Interface
l *logrus.Logger
l *slog.Logger
ctx context.Context
cancel context.CancelFunc
sshStart func()
@@ -151,7 +151,7 @@ func (c *Control) Stop() {
c.CloseAllTunnels(false)
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
c.l.Error("Close interface failed", "error", err)
}
c.stateLock.Lock()
c.state = StateStopped
@@ -166,7 +166,7 @@ func (c *Control) ShutdownBlock() {
rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.l.Info("Caught signal, shutting down", "signal", sig)
c.Stop()
}
@@ -303,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h)
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
c.l.Debug("Sending close tunnel message",
"vpnAddrs", h.vpnAddrs,
"udpAddr", h.remote,
)
closed++
}

View File

@@ -6,7 +6,6 @@ import (
"reflect"
"testing"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
@@ -83,7 +82,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
f: &Interface{
hostMap: hm,
},
l: logrus.New(),
l: test.NewLogger(),
}
thi := c.GetHostInfoByVpnAddr(vpnIp, false)

View File

@@ -3,6 +3,7 @@ package nebula
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"strconv"
@@ -12,13 +13,12 @@ import (
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
type dnsServer struct {
sync.RWMutex
l *logrus.Logger
l *slog.Logger
ctx context.Context
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
@@ -55,7 +55,7 @@ type dnsServer struct {
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
// watcher that tears the listener down on nebula shutdown. The returned
// pointer is always non-nil, even on error.
func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
ds := &dnsServer{
l: l,
ctx: ctx,
@@ -69,7 +69,7 @@ func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState
c.RegisterReloadCallback(func(c *config.C) {
if err := ds.reload(c, false); err != nil {
l.WithError(err).Error("Failed to reload DNS responder from config")
ds.l.Error("Failed to reload DNS responder from config", "error", err)
}
})
@@ -145,7 +145,7 @@ func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reaso
<-started
}
if err := srv.Shutdown(); err != nil {
d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder")
d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
}
}
@@ -188,7 +188,7 @@ func (d *dnsServer) Start() {
}
}()
d.l.WithField("dnsListener", addr).Info("Starting DNS responder")
d.l.Info("Starting DNS responder", "dnsListener", addr)
err := server.ListenAndServe()
close(done)
@@ -201,7 +201,7 @@ func (d *dnsServer) Start() {
}
if err != nil {
d.l.WithError(err).Warn("Failed to run the DNS responder")
d.l.Warn("Failed to run the DNS responder", "error", err)
}
}
@@ -314,6 +314,7 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
}
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
// type must be answered with NOERROR and an empty answer section (NODATA),
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
@@ -323,7 +324,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
switch q.Qtype {
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
if debugEnabled {
d.l.Debug("DNS query", "type", qType, "name", q.Name)
}
ip, nameExists := d.Query(q.Qtype, q.Name)
if nameExists {
anyNameExists = true
@@ -339,7 +342,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}
d.l.Debugf("Query for TXT %s", q.Name)
if debugEnabled {
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
}
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))

View File

@@ -2,7 +2,7 @@ package nebula
import (
"context"
"io"
"log/slog"
"net"
"net/netip"
"strconv"
@@ -10,7 +10,6 @@ import (
"time"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -30,7 +29,7 @@ func (stubDNSWriter) TsigTimersOnly(bool) {}
func (stubDNSWriter) Hijack() {}
func TestParsequery(t *testing.T) {
l := logrus.New()
l := slog.New(slog.DiscardHandler)
hostMap := &HostMap{}
ds := &dnsServer{
l: l,
@@ -137,10 +136,9 @@ func Test_getDnsServerAddr(t *testing.T) {
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
t.Helper()
l := logrus.New()
l.Out = io.Discard
sl := slog.New(slog.DiscardHandler)
ds := &dnsServer{
l: l,
l: sl,
ctx: context.Background(),
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
@@ -148,7 +146,7 @@ func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
}
ds.mux = dns.NewServeMux()
ds.mux.HandleFunc(".", ds.handleDnsRequest)
return ds, config.NewC(l)
return ds, config.NewC(nil)
}
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {

View File

@@ -11,7 +11,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
@@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) {
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
@@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) {
theirControl.Start()
r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
r.Log("Wait for a packet from them to me")
l.Info("Wait for a packet from them to me; myControl")
r.Log("Wait for a packet from them to me; myControl")
r.RouteForAllUntilTxTun(myControl)
l.Info("Wait for a packet from them to me; theirControl")
r.Log("Wait for a packet from them to me; theirControl")
r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels")
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
len(myControl.GetHostmap().Indexes),
len(theirControl.GetHostmap().Indexes),
len(relayControl.GetHostmap().Indexes),
)
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
retries := 60
for hostInfos > 6 && retries > 0 {
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
l.WithFields(
logrus.Fields{
"myControl": len(myControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
len(myControl.GetHostmap().Indexes),
len(theirControl.GetHostmap().Indexes),
len(relayControl.GetHostmap().Indexes),
)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
@@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) {
}
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
myControl.Stop()

View File

@@ -4,7 +4,6 @@
package e2e
import (
"fmt"
"io"
"net/netip"
"os"
@@ -12,15 +11,18 @@ import (
"testing"
"time"
"log/slog"
"dario.cat/mergo"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
"level": l.Level.String(),
"level": testLogLevelName(),
},
"timers": m{
"pending_deletion_interval": 2,
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
"level": testLogLevelName(),
},
"timers": m{
"pending_deletion_interval": 2,
@@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
return a
}
func NewTestLogger() *logrus.Logger {
l := logrus.New()
func NewTestLogger() *slog.Logger {
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(io.Discard)
l.SetLevel(logrus.PanicLevel)
return l
return slog.New(slog.NewTextHandler(io.Discard, nil))
}
level := slog.LevelInfo
switch v {
case "2":
l.SetLevel(logrus.DebugLevel)
level = slog.LevelDebug
case "3":
l.SetLevel(logrus.TraceLevel)
default:
l.SetLevel(logrus.InfoLevel)
level = logging.LevelTrace
}
return l
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
}
// testLogLevelName returns the level name string accepted by logging.ApplyConfig
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
func testLogLevelName() string {
	// Anything other than the two explicit verbosity bumps (including an
	// unset/empty TEST_LOGS) maps to the default "info" level.
	if v := os.Getenv("TEST_LOGS"); v == "2" {
		return "debug"
	} else if v == "3" {
		return "trace"
	}
	return "info"
}

View File

@@ -292,23 +292,17 @@ tun:
# Configure logging level
logging:
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
# scenarios. Debug logging is also CPU intensive and will decrease performance overall.
# Only enable debug logging while actively investigating an issue.
# trace, debug, info, warn, or error. Default is info and is reloadable.
# fatal and panic are accepted for backwards compatibility and map to error.
#NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
# scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
# Only enable debug or trace logging while actively investigating an issue.
level: info
# json or text formats currently available. Default is text
# json or text formats currently available. Default is text.
format: text
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
# Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
#disable_timestamp: true
# timestamp format is specified in Go time format, see:
# https://golang.org/pkg/time/#pkg-constants
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
# default when `format: text`:
# when TTY attached: seconds since beginning of execution
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
# As an example, to log as RFC3339 with millisecond precision, set to:
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
# Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
#stats:
#type: graphite

View File

@@ -7,9 +7,9 @@ import (
"net"
"os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/service"
)
@@ -64,8 +64,7 @@ pki:
return err
}
logger := logrus.New()
logger.Out = os.Stdout
logger := logging.NewLogger(os.Stdout)
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {

View File

@@ -1,11 +1,13 @@
package nebula
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
"log/slog"
"net/netip"
"reflect"
"slices"
@@ -16,7 +18,6 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -67,7 +68,7 @@ type Firewall struct {
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
l *logrus.Logger
l *slog.Logger
}
type firewallMetrics struct {
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration
var tmin, tmax time.Duration
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop":
fw.InSendReject = false
default:
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`")
l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
fw.InSendReject = false
}
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop":
fw.OutSendReject = false
default:
l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`")
l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
fw.OutSendReject = false
}
@@ -268,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
case firewall.ProtoICMP, firewall.ProtoICMPv6:
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
if startPort != firewall.PortAny {
f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule")
f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
}
startPort = firewall.PortAny
endPort = firewall.PortAny
@@ -290,8 +291,9 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
f.l.Info("Firewall rule added",
"firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
)
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
}
@@ -314,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
@@ -372,7 +374,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
startPort = firewall.PortAny
endPort = firewall.PortAny
if sPort != "" {
l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule")
l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
}
default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
@@ -396,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
}
if warning := r.sanity(); warning != nil {
l.Warnf("%s rule #%v; %s", table, i, warning)
l.Warn("firewall rule sanity check",
"table", table,
"rule", i,
"warning", warning,
)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
@@ -528,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
"fwPacket", fp,
"incoming", c.incoming,
"rulesVersion", f.rulesVersion,
"oldRulesVersion", c.rulesVersion,
)
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
}
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
"fwPacket", fp,
"incoming", c.incoming,
"rulesVersion", f.rulesVersion,
"oldRulesVersion", c.rulesVersion,
)
}
c.rulesVersion = f.rulesVersion
@@ -935,7 +941,7 @@ type rule struct {
CASha string
}
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
r := rule{}
m, ok := p.(map[string]any)
@@ -966,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
}
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
l.Warn("group was an array with a single value, converting to simple value",
"table", table,
"rule", i,
)
m["group"] = v[0]
}

View File

@@ -2,10 +2,9 @@ package firewall
import (
"context"
"log/slog"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// ConntrackCache is used as a local routine cache to know if a given flow
@@ -16,15 +15,17 @@ type ConntrackCacheTicker struct {
cacheV uint64
cacheTick atomic.Uint64
l *slog.Logger
cache ConntrackCache
}
func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker {
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
l: l,
cache: ConntrackCache{},
}
@@ -48,15 +49,15 @@ func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
func (c *ConntrackCacheTicker) Get() ConntrackCache {
if c == nil {
return nil
}
if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
if c.l.Enabled(context.Background(), slog.LevelDebug) {
c.l.Debug("resetting conntrack cache", "len", ll)
}
c.cache = make(ConntrackCache, ll)
}

69
firewall/cache_test.go Normal file
View File

@@ -0,0 +1,69 @@
package firewall
import (
"bytes"
"log/slog"
"strings"
"testing"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
// The tests below pin the log format produced by ConntrackCacheTicker.Get
// so changes cannot silently break what operators are grepping for. The
// ticker's internal state (cache + cacheTick) is poked directly to avoid
// racing a goroutine-driven tick in tests.
// newFixedTicker builds a ConntrackCacheTicker whose cache already holds
// cacheLen synthetic entries and whose tick counter has advanced past cacheV,
// so the next Get() call deterministically takes the reset (and logging) path.
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
	t.Helper()
	c := &ConntrackCacheTicker{
		l:     l,
		cache: make(ConntrackCache, cacheLen),
	}
	for port := 1; port <= cacheLen; port++ {
		c.cache[Packet{LocalPort: uint16(port)}] = struct{}{}
	}
	// cacheV starts at 0, so storing 1 here forces Get() to reset the cache.
	c.cacheTick.Store(1)
	return c
}
// TestConntrackCacheTicker_Get_TextFormat pins the text-handler line emitted
// when Get() resets a non-empty cache at debug level.
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
	var buf bytes.Buffer
	l := test.NewLoggerWithOutputAndLevel(&buf, slog.LevelDebug)
	newFixedTicker(t, l, 3).Get()
	assert.Equal(t, `level=DEBUG msg="resetting conntrack cache" len=3`+"\n", buf.String())
}
// TestConntrackCacheTicker_Get_JSONFormat pins the JSON-handler record emitted
// when Get() resets a non-empty cache at debug level.
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
	var buf bytes.Buffer
	l := test.NewJSONLoggerWithOutput(&buf, slog.LevelDebug)
	newFixedTicker(t, l, 2).Get()
	got := strings.TrimSpace(buf.String())
	assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, got)
}
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
c := newFixedTicker(t, l, 5)
c.Get()
assert.Empty(t, buf.String())
}
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
c := newFixedTicker(t, l, 0)
c.Get()
assert.Empty(t, buf.String())
}

View File

@@ -3,13 +3,13 @@ package nebula
import (
"bytes"
"errors"
"log/slog"
"math"
"net/netip"
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
}
func TestFirewall_AddRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@@ -177,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
}
func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
@@ -254,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
}
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -485,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -544,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
}
func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -633,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
}
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -671,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -736,9 +729,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}
func TestFirewall_ICMPPortBehavior(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -880,9 +872,8 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
@@ -1045,25 +1036,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err)
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
@@ -1073,25 +1064,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
@@ -1100,35 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
func TestAddFirewallRulesFromConfig(t *testing.T) {
l := test.NewLogger()
// Test adding tcp rule
conf := config.NewC(l)
conf := config.NewC(test.NewLogger())
mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule no port
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -1136,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -1151,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
conf = config.NewC(test.NewLogger())
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
@@ -1234,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
}
func TestFirewall_convertRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
// Ensure group array of 1 is converted and a warning is printed
c := map[string]any{
@@ -1244,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
}
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
assert.Contains(t, ob.String(), "table=test")
assert.Contains(t, ob.String(), "rule=1")
require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups)
@@ -1270,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
@@ -1386,7 +1377,7 @@ type testsetup struct {
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
@@ -1397,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
@@ -1414,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
l := test.NewLoggerWithOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out

1
go.mod
View File

@@ -18,7 +18,6 @@ require (
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.23.2
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.4
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1

2
go.sum
View File

@@ -133,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=

View File

@@ -2,11 +2,12 @@ package nebula
import (
"bytes"
"context"
"log/slog"
"net/netip"
"time"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
@@ -18,8 +19,11 @@ import (
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
f.l.Error("Failed to generate index",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
@@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
f.l.Error("Unable to handshake with host because no certificate is available",
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
f.l.Error("Failed to create connection state",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
hh.hostinfo.ConnectionState = ci
@@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
f.l.Error("Failed to marshal handshake message",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"certVersion", v,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
@@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
f.l.Error("Failed to call noise.WriteMessage",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
@@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
f.l.Error("Unable to handshake with host because no certificate is available",
"from", via,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", cs.initiatingVersion,
)
return
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
f.l.Error("Failed to create connection state",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
f.l.Error("Failed to call noise.ReadMessage",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
f.l.Error("Failed unmarshal handshake message",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
f.l.Info("Handshake did not contain a certificate",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp)
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
attrs := []slog.Attr{
slog.Any("error", err),
slog.Any("from", via),
slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
slog.Any("certVpnNetworks", rc.Networks()),
slog.String("certFingerprint", fp),
}
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
}
e.Info("Invalid certificate from host")
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
f.l.Info("public key mismatch between certificate and handshake",
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"cert", remoteCert,
)
return
}
@@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{
"from": via,
"handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Might be unable to handshake with host due to missing certificate version",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"cert", remoteCert,
)
}
} else {
// Record the certificate we are actually using
@@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate")
f.l.Info("No networks in certificate",
"error", err,
"from", via,
"cert", remoteCert,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
f.l.Error("Refusing to handshake with myself",
"vpnNetworks", vpnNetworks,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
vpnAddrs[i] = network.Addr()
@@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
if !via.IsRelayed {
// We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
Debug("lighthouse.remote_allow_list denied incoming handshake")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", vpnAddrs,
"from", via,
)
}
return
}
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
f.l.Error("Failed to generate index",
"error", err,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
},
}
msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs,
"from": via,
"certName": certName,
"certVersion": certVersion,
"fingerprint": fingerprint,
"issuer": issuer,
"initiatorIndex": hs.Details.InitiatorIndex,
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
msgRxL := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
@@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
"myCertVersion", ci.myCert.Version(),
)
return
}
@@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
f.l.Error("Failed to marshal handshake message",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
f.l.Error("Failed to call noise.WriteMessage",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
f.l.Error("Noise did not arrive at a key",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
@@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
if !via.IsRelayed {
err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
f.l.Error("Failed to send handshake message",
"vpnAddrs", existing.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
"error", err,
)
} else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
f.l.Info("Handshake message sent",
"vpnAddrs", existing.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
}
return
} else {
@@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
f.l.Info("Handshake message sent",
"vpnAddrs", existing.vpnAddrs,
"relay", via.relayHI.vpnAddrs[0],
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
return
}
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake too old")
f.l.Info("Handshake too old",
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"oldHandshakeTime", existing.lastHandshakeTime,
"newHandshakeTime", hostinfo.lastHandshakeTime,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
Error("Failed to add HostInfo due to localIndex collision")
f.l.Error("Failed to add HostInfo due to localIndex collision",
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"localIndex", hostinfo.localIndexId,
"collision", existing.vpnAddrs,
)
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to add HostInfo to HostMap")
f.l.Error("Failed to add HostInfo to HostMap",
"error", err,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
}
@@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed {
err = f.outside.WriteTo(msg, via.UdpAddr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
log := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
if err != nil {
log.WithError(err).Error("Failed to send handshake")
log.Error("Failed to send handshake", "error", err)
} else {
log.Info("Handshake message sent")
}
@@ -448,14 +533,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
// it's correctly marked as working.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
f.l.Info("Handshake message sent",
"vpnAddrs", vpnAddrs,
"relay", via.relayHI.vpnAddrs[0],
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
}
f.connectionManager.AddTrafficWatch(hostinfo)
@@ -483,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
)
}
return false
}
}
@@ -491,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
f.l.Error("Failed to call noise.ReadMessage",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"header", h,
)
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
f.l.Error("Noise did not arrive at a key",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
// the handshake state machine. Tear it down
@@ -512,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
f.l.Error("Failed unmarshal handshake message",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
@@ -521,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
f.l.Info("Handshake did not contain a certificate",
"error", err,
"from", via,
"vpnAddrs", hostinfo.vpnAddrs,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true
}
@@ -535,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp).
WithField("certVpnNetworks", rc.Networks())
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
attrs := []slog.Attr{
slog.Any("error", err),
slog.Any("from", via),
slog.Any("vpnAddrs", hostinfo.vpnAddrs),
slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
slog.String("certFingerprint", fp),
slog.Any("certVpnNetworks", rc.Networks()),
}
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
}
e.Info("Invalid certificate from host")
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return true
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
f.l.Info("public key mismatch between certificate and handshake",
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cert", remoteCert,
)
return true
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("No networks in certificate")
f.l.Info("No networks in certificate",
"error", err,
"from", via,
"vpnAddrs", hostinfo.vpnAddrs,
"cert", remoteCert,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true
}
@@ -601,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
// Ensure the right host responded
if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
f.l.Info("Incorrect host responded to handshake",
"intendedVpnAddrs", hostinfo.vpnAddrs,
"haveVpnNetworks", vpnNetworks,
"from", via,
"certName", certName,
"certVersion", certVersion,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// Release our old handshake from pending, it should not continue
f.handshakeManager.DeleteHostInfo(hostinfo)
@@ -618,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
f.l.Info("Blocked addresses for handshakes",
"blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
"vpnNetworks", vpnNetworks,
"remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
)
// Swap the packet store to benefit the original intended recipient
newHH.packetStore = hh.packetStore
@@ -639,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore))
msgRxL := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"durationNs", duration,
"sentCachedPackets", len(hh.packetStore),
)
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
@@ -663,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Sending stored packets",
"count", len(hh.packetStore),
)
}
if len(hh.packetStore) > 0 {

View File

@@ -6,13 +6,13 @@ import (
"crypto/rand"
"encoding/binary"
"errors"
"log/slog"
"net/netip"
"slices"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
@@ -59,7 +59,7 @@ type HandshakeManager struct {
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
f *Interface
l *logrus.Logger
l *slog.Logger
// can be used to trigger outbound handshake for the given vpnIp
trigger chan netip.Addr
@@ -78,32 +78,32 @@ type HandshakeHostInfo struct {
hostinfo *HostInfo
}
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
if len(hh.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
copy(tempPacket, packet)
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
if l.Level >= logrus.DebugLevel {
hh.hostinfo.logger(l).
WithField("length", len(hh.packetStore)).
WithField("stored", true).
Debugf("Packet store")
if l.Enabled(context.Background(), slog.LevelDebug) {
hh.hostinfo.logger(l).Debug("Packet store",
"length", len(hh.packetStore),
"stored", true,
)
}
} else {
m.dropped.Inc(1)
if l.Level >= logrus.DebugLevel {
hh.hostinfo.logger(l).
WithField("length", len(hh.packetStore)).
WithField("stored", false).
Debugf("Packet store")
if l.Enabled(context.Background(), slog.LevelDebug) {
hh.hostinfo.logger(l).Debug("Packet store",
"length", len(hh.packetStore),
"stored", false,
)
}
}
}
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
indexes: map[uint32]*HandshakeHostInfo{},
@@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
// First remote allow list check before we know the vpnIp
if !via.IsRelayed {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via)
return
}
}
@@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo := hh.hostinfo
// If we are out of time, clean up
if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
WithField("initiatorIndex", hh.hostinfo.localIndexId).
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
Info("Handshake timed out")
hh.hostinfo.logger(hm.l).Info("Handshake timed out",
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
"initiatorIndex", hh.hostinfo.localIndexId,
"remoteIndex", hh.hostinfo.remoteIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"durationNs", time.Since(hh.startTime).Nanoseconds(),
)
hm.metricTimedOut.Inc(1)
hm.DeleteHostInfo(hostinfo)
return
@@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(hm.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
hostinfo.logger(hm.l).Error("Failed to send handshake message",
"udpAddr", addr,
"initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"error", err,
)
} else {
sentTo = append(sentTo, addr)
@@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
// so only log when the list of remotes has changed
if remotesHaveChanged {
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
} else if hm.l.Level >= logrus.DebugLevel {
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Debug("Handshake message sent")
hostinfo.logger(hm.l).Info("Handshake message sent",
"udpAddrs", sentTo,
"initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
} else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(hm.l).Debug("Handshake message sent",
"udpAddrs", sentTo,
"initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
}
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to
@@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
hm.f.Handshake(relay)
continue
}
@@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
}
m := NebulaControl{
@@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": relay}).
Info("send CreateRelayRequest")
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", idx,
"relay", relay,
)
}
}
continue
@@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
switch existingRelay.State {
case Established:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
@@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}).
Info("send CreateRelayRequest")
hm.l.Info("send CreateRelayRequest",
"relayFrom", hm.f.myVpnAddrs[0],
"relayTo", vpnIp,
"initiatorRelayIndex", existingRelay.LocalIndex,
"relay", relay,
)
}
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State).
WithField("relay", relay).
Errorf("Relay unexpected state")
hostinfo.logger(hm.l).Error("Relay unexpected state",
"vpnIp", vpnIp,
"state", existingRelay.State,
"relay", relay,
)
}
}
@@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
@@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
"remoteIndex", hostinfo.remoteIndexId,
"collision", existingRemoteIndex.vpnAddrs,
)
}
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
@@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
hm.indexes = map[uint32]*HandshakeHostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Pending hostmap hostInfo deleted",
"hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
)
}
}
@@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() {
// Utility functions below
func generateIndex(l *logrus.Logger) (uint32, error) {
func generateIndex(l *slog.Logger) (uint32, error) {
b := make([]byte, 4)
// Let zero mean we don't know the ID, so don't generate zero
@@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
for index == 0 {
_, err := rand.Read(b)
if err != nil {
l.Errorln(err)
l.Error("Failed to generate index", "error", err)
return 0, err
}
index = binary.BigEndian.Uint32(b)
}
if l.Level >= logrus.DebugLevel {
l.WithField("index", index).
Debug("Generated index")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Generated index", "index", index)
}
return index, nil
}

View File

@@ -1,9 +1,11 @@
package nebula
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
@@ -13,10 +15,10 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
)
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
@@ -60,7 +62,7 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix]
l *logrus.Logger
l *slog.Logger
}
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
@@ -313,7 +315,7 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
hm := newHostMap(l)
hm.reload(c, true)
@@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm.reload(c, false)
})
l.WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")
l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
return hm
}
func newHostMap(l *logrus.Logger) *HostMap {
func newHostMap(l *slog.Logger) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
@@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
hm.l.Warn("Failed to parse preferred ranges, ignoring",
"error", err,
"range", rawPreferredRanges,
)
continue
}
@@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
hm.l.Info("preferred_ranges changed",
"oldPreferredRanges", *oldRanges,
"newPreferredRanges", preferredRanges,
)
}
}
}
@@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
hm.Indexes = map[uint32]*HostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Hostmap hostInfo deleted",
"hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
)
}
if isLastHostinfo {
@@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
Debug("Hostmap vpnIp added")
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.Debug("Hostmap vpnIp added",
"hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
)
}
}
@@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
}
}
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
if i == nil {
return logrus.NewEntry(l)
return l
}
li := l.WithField("vpnAddrs", i.vpnAddrs).
WithField("localIndex", i.localIndexId).
WithField("remoteIndex", i.remoteIndexId)
li := l.With(
"vpnAddrs", i.vpnAddrs,
"localIndex", i.localIndexId,
"remoteIndex", i.remoteIndexId,
)
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Certificate.Name())
li = li.With("certName", peerCert.Certificate.Name())
}
}
@@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage
var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
if l.Level >= logrus.TraceLevel {
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
if l.Enabled(context.Background(), logging.LevelTrace) {
l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
"interfaceName", i.Name,
"allow", allow,
)
}
if !allow {
@@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
}
if !addr.IsValid() {
if l.Level >= logrus.DebugLevel {
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("addr was invalid", "localAddr", rawAddr)
}
continue
}
@@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel {
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
if l.Enabled(context.Background(), logging.LevelTrace) {
l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
"localAddr", addr,
"allowed", isAllowed,
)
}
if !isAllowed {
continue

View File

@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
c := config.NewC(test.NewLogger())
hm := NewHostMapFromConfig(l, c)

119
inside.go
View File

@@ -1,9 +1,10 @@
package nebula
import (
"context"
"log/slog"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
@@ -14,8 +15,11 @@ import (
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while validating outbound packet",
"packet", packet,
"error", err,
)
}
return
}
@@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
if err != nil {
f.l.WithError(err).Error("Failed to forward to tun")
f.l.Error("Failed to forward to tun", "error", err)
}
}
// Otherwise, drop. On linux, we should never see these packets - Linux
@@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
"vpnAddr", fwPacket.RemoteAddr,
"fwPacket", fwPacket,
)
}
return
}
@@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping outbound packet",
"fwPacket", fwPacket,
"reason", dropReason,
)
}
}
}
@@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
_, err := f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
f.l.Error("Failed to write to tun", "error", err)
}
}
@@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
}
if len(out) > iputil.MaxRejectPacketSize {
if f.l.GetLevel() >= logrus.InfoLevel {
f.l.
WithField("packet", packet).
WithField("outPacket", out).
Info("rejectOutside: packet too big, not sending")
if f.l.Enabled(context.Background(), slog.LevelInfo) {
f.l.Info("rejectOutside: packet too big, not sending",
"packet", packet,
"outPacket", out,
)
}
return
}
@@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
// This would also need to interact with unsafe_route updates through reloading the config or
// use of the use_system_route_table option
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("destination", destinationAddr).
WithField("originalGateway", gatewayAddr).
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
"destination", destinationAddr,
"originalGateway", gatewayAddr,
)
}
for i := range gateways {
@@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
fp := &firewall.Packet{}
err := newPacket(p, false, fp)
if err != nil {
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
return
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping cached packet",
"fwPacket", fp,
"reason", dropReason,
)
}
return
}
@@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
})
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", vpnAddr).
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
"vpnAddr", vpnAddr,
)
}
return
}
@@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo,
if noiseutil.EncryptLockNeeded {
via.ConnectionState.writeLock.Unlock()
}
via.logger(f.l).
WithField("outCap", cap(out)).
WithField("payloadLen", len(ad)).
WithField("headerLen", len(out)).
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()).
Error("SendVia out buffer not large enough for relay")
via.logger(f.l).Error("SendVia out buffer not large enough for relay",
"outCap", cap(out),
"payloadLen", len(ad),
"headerLen", len(out),
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
)
return
}
@@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo,
via.ConnectionState.writeLock.Unlock()
}
if err != nil {
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
return
}
err = f.writers[0].WriteTo(out, via.remote)
if err != nil {
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
}
f.connectionManager.RelayUsed(relay.LocalIndex)
}
@@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
"vpnAddrs", hostinfo.vpnAddrs,
)
}
}
@@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
"error", err,
"udpAddr", remote,
"counter", c,
"attemptedCounter", c,
)
return
}
if remote.IsValid() {
err = f.writers[q].WriteTo(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
"error", err,
"udpAddr", remote,
)
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
"error", err,
"udpAddr", remote,
)
}
} else {
// Try to send via a relay
@@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
"relay", relayIP,
"error", err,
)
continue
}
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"sync"
"sync/atomic"
@@ -12,7 +13,7 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -46,7 +47,7 @@ type InterfaceConfig struct {
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
l *logrus.Logger
l *slog.Logger
}
type Interface struct {
@@ -100,7 +101,7 @@ type Interface struct {
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger
l *slog.Logger
}
type EncWriter interface {
@@ -223,13 +224,16 @@ func (f *Interface) activate() error {
addr, err := f.outside.LocalAddr()
if err != nil {
f.l.WithError(err).Error("Failed to get udp listen address")
f.l.Error("Failed to get udp listen address", "error", err)
}
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
f.l.Info("Nebula interface is active",
"interface", f.inside.Name(),
"networks", f.myVpnNetworks,
"build", f.version,
"udpAddr", addr,
"boringcrypto", boringEnabled(),
)
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
@@ -305,7 +309,7 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
h := &header.H{}
@@ -313,15 +317,15 @@ func (f *Interface) listenOut(i int) {
nb := make([]byte, 12, 12)
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
})
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading inbound packet, closing")
f.l.Error("Error while reading inbound packet, closing", "error", err)
f.onFatal(err)
}
f.l.Debugf("underlay reader %v is done", i)
f.l.Debug("underlay reader is done", "reader", i)
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -330,22 +334,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
for {
n, err := reader.Read(packet)
if err != nil {
if !f.closed.Load() {
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
f.onFatal(err)
}
break
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
}
f.l.Debugf("overlay reader %v is done", i)
f.l.Debug("overlay reader is done", "reader", i)
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -365,7 +369,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
if initial || c.HasChanged("pki.disconnect_invalid") {
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
if !initial {
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
}
}
}
@@ -379,7 +383,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
f.l.Error("Error while creating firewall during reload", "error", err)
return
}
@@ -392,10 +396,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
} else {
fw.Conntrack = conntrack
}
@@ -403,10 +408,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
f.firewall = fw
oldFw.Destroy()
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed")
f.l.Info("New firewall has been installed",
"firewallHashes", fw.GetRuleHashes(),
"oldFirewallHashes", oldFw.GetRuleHashes(),
"rulesVersion", fw.rulesVersion,
)
}
func (f *Interface) reloadSendRecvError(c *config.C) {
@@ -428,8 +434,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
}
}
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
Info("Loaded send_recv_error config")
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
}
}
@@ -452,8 +457,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
}
}
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
Info("Loaded accept_recv_error config")
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
}
}
@@ -527,7 +531,7 @@ func (f *Interface) Close() error {
for i, u := range f.writers {
err := u.Close()
if err != nil {
f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket")
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
errs = append(errs, err)
}
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
@@ -15,10 +16,10 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
)
@@ -76,12 +77,12 @@ type LightHouse struct {
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
l *slog.Logger
}
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 {
@@ -133,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
case *util.ContextualError:
v.Log(l)
case error:
l.WithError(err).Error("failed to reload lighthouse")
l.Error("failed to reload lighthouse", "error", err)
}
})
@@ -205,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
addr := addrs[0].Unmap()
if lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
"addr", rawAddr,
"entry", i+1,
)
continue
}
@@ -224,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
if !initial {
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
lh.l.Info("lighthouse.interval changed",
"interval", lh.interval.Load(),
)
if lh.updateCancel != nil {
// May not always have a running routine
@@ -336,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
for _, v := range c.GetStringSlice("relay.relays", nil) {
configRIP, err := netip.ParseAddr(v)
if err != nil {
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
lh.l.Warn("Parse relay from config failed",
"relay", v,
"error", err,
)
} else {
lh.l.WithField("relay", v).Info("Read relay from config")
lh.l.Info("Read relay from config", "relay", v)
relaysForMe = append(relaysForMe, configRIP)
}
}
@@ -363,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
}
if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
"vpnAddr", addr,
"networks", lh.myVpnNetworks,
)
}
out[i] = addr
}
@@ -435,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
}
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
"vpnAddr", vpnAddr,
"networks", lh.myVpnNetworks,
"entry", i+1,
)
}
vals, ok := v.([]any)
@@ -537,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok {
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
for _, addr := range allVpnAddrs {
srm := lh.addrMap[addr]
if srm == rm {
delete(lh.addrMap, addr)
if lh.l.Level >= logrus.DebugLevel {
lh.l.Debugf("deleting %s from lighthouse.", addr)
if debugEnabled {
lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
}
}
}
@@ -659,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddrs", vpnAddrs,
"udpAddr", to,
"allow", allow,
)
}
if !allow {
return false
@@ -678,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
udpAddr := protoV4AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
}
if !allow {
@@ -698,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
udpAddr := protoV6AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
Trace("remoteAllowList.Allow")
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
"vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
}
if !allow {
@@ -775,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
if v == cert.Version1 {
if !addr.Is4() {
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
Error("Can't query lighthouse for v6 address using a v1 protocol")
lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
@@ -787,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v1Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v1 query payload")
lh.l.Error("Failed to marshal lighthouse v1 query payload",
"error", err,
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -804,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v2Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v2 query payload")
lh.l.Error("Failed to marshal lighthouse v2 query payload",
"error", err,
"queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -815,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried++
} else {
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
lh.l.Debug("unsupported protocol version",
"op", "query",
"queryVpnAddr", addr,
"version", v,
)
continue
}
}
@@ -907,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
if v == cert.Version1 {
if v1Update == nil {
if !lh.myVpnNetworks[0].Addr().Is4() {
lh.l.WithField("lighthouseAddr", lhVpnAddr).
Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
"lighthouseAddr", lhVpnAddr,
)
continue
}
var relays []uint32
@@ -932,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
v1Update, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
Error("Error while marshaling for lighthouse v1 update")
lh.l.Error("Error while marshaling for lighthouse v1 update",
"error", err,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -959,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
v2Update, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
Error("Error while marshaling for lighthouse v2 update")
lh.l.Error("Error while marshaling for lighthouse v2 update",
"error", err,
"lighthouseAddr", lhVpnAddr,
)
continue
}
}
@@ -969,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
updated++
} else {
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v)
lh.l.Debug("unsupported protocol version",
"op", "update",
"version", v,
)
continue
}
}
@@ -983,7 +1024,7 @@ type LightHouseHandler struct {
out []byte
pb []byte
meta *NebulaMeta
l *logrus.Logger
l *slog.Logger
}
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
@@ -1032,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
lhh.l.Error("Failed to unmarshal lighthouse packet",
"error", err,
"vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return
}
if n.Details == nil {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
lhh.l.Error("Invalid lighthouse update",
"vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return
}
@@ -1067,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("I don't answer queries, but received one", "from", addr)
}
return
}
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Dropping malformed HostQuery",
"from", fromVpnAddrs,
"details", n.Details,
)
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
"vpnAddrs", fromVpnAddrs,
"queryVpnAddr", queryVpnAddr,
)
}
return
}
@@ -1110,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
}
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
lhh.l.Error("Failed to marshal lighthouse host query reply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1138,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok {
whereToPunch = newDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("unable to punch to host, no addresses in common",
"to", crt.Networks(),
)
}
}
}
@@ -1165,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
}
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
lhh.l.Error("Failed to marshal lighthouse host was queried for",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1207,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
}
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("version", v).Debug("unsupported protocol version")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("unsupported protocol version",
"op", "coalesceAnswers",
"version", v,
)
}
}
}
@@ -1221,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Error("dropping malformed HostQueryReply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
}
return
}
@@ -1247,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
}
return
}
@@ -1271,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Host sent invalid update",
"vpnAddrs", fromVpnAddrs,
"answer", detailsVpnAddr,
)
}
return
}
@@ -1294,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
switch useVersion {
case cert.Version1:
if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
"vpnAddrs", fromVpnAddrs,
)
return
}
vpnAddrB := fromVpnAddrs[0].As4()
@@ -1302,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
case cert.Version2:
// do nothing, we want to send a blank message
default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
lhh.l.Error("invalid protocol version", "useVersion", useVersion)
return
}
ln, err := n.MarshalTo(lhh.pb)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
lhh.l.Error("Failed to marshal lighthouse host update ack",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return
}
@@ -1325,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("dropping invalid HostPunchNotification",
"details", n.Details,
"error", err,
)
}
return
}
@@ -1343,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}()
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Punching",
"vpnPeer", vpnPeer,
"logVpnAddr", logVpnAddr,
)
}
}
@@ -1369,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
if lhh.lh.punchy.GetRespond() {
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debug("Sending a nebula test packet",
"vpnAddr", detailsVpnAddr,
)
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine

View File

@@ -1,45 +0,0 @@
package nebula
import (
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// configLogger applies the logging.* config keys to the logrus logger l:
// logging.level, logging.format (text or json), logging.timestamp_format,
// and logging.disable_timestamp. It returns an error for an unknown level
// or format; on error the logger may be partially configured (level is set
// before the format is validated).
func configLogger(l *logrus.Logger, c *config.C) error {
	// set up our logging level
	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
	}
	l.SetLevel(logLevel)

	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
	// An explicitly configured timestamp format implies FullTimestamp; with
	// no format configured we fall back to RFC3339 and logrus's default
	// (relative) timestamp mode.
	timestampFormat := c.GetString("logging.timestamp_format", "")
	fullTimestamp := (timestampFormat != "")
	if timestampFormat == "" {
		timestampFormat = time.RFC3339
	}

	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
	switch logFormat {
	case "text":
		l.Formatter = &logrus.TextFormatter{
			TimestampFormat:  timestampFormat,
			FullTimestamp:    fullTimestamp,
			DisableTimestamp: disableTimestamp,
		}
	case "json":
		l.Formatter = &logrus.JSONFormatter{
			TimestampFormat:  timestampFormat,
			DisableTimestamp: disableTimestamp,
		}
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
	}

	return nil
}

233
logging/logger.go Normal file
View File

@@ -0,0 +1,233 @@
// Package logging wires the nebula runtime-reconfigurable slog handler used
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
// NewLogger, then call ApplyConfig at startup and from a config reload
// callback to push logging.level, logging.format, and
// logging.disable_timestamp changes onto the logger without rebuilding it.
package logging
import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"sync/atomic"
"time"
)
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
// here keeps the logging package from depending on config directly, which
// would cycle through the shared test helpers (test.NewLogger imports
// logging, and config's tests import test). *config.C satisfies this
// interface structurally with no adapter.
type Config interface {
	// GetString returns the string value at key, or def when the key is unset.
	GetString(key, def string) string
	// GetBool returns the bool value at key, or def when the key is unset.
	GetBool(key string, def bool) bool
}
// LevelTrace is a custom slog level below Debug, used when logging.level is
// "trace". slog has no builtin trace level; the value is one step below
// slog.LevelDebug in slog's 4-point spacing.
const LevelTrace = slog.Level(-8)
// NewLogger returns a *slog.Logger whose level, format, and timestamp
// emission can be reconfigured at runtime via ApplyConfig and the SSH debug
// commands. It starts out as info-level text output so log calls made
// before ApplyConfig runs still produce output; timestamps use slog's
// default RFC3339Nano form until logging.disable_timestamp turns them off.
//
// Reconfiguration is discovered by structurally type-asserting l.Handler()
// for the subset of {SetLevel(slog.Level), SetFormat(string) error,
// SetDisableTimestamp(bool)} a handler implements, so replacement handlers
// (tests, platform-specific sinks) opt in per capability. A plain
// *slog.Logger without these methods is a silent no-op.
func NewLogger(w io.Writer) *slog.Logger {
	h := NewHandler(w)
	return slog.New(h)
}
// NewHandler builds the *Handler that NewLogger wraps. It is exported for
// platform-specific sinks (notably cmd/nebula-service/logs_windows.go)
// that wrap the handler with extra behavior -- e.g. tagging each record
// with its Event Log severity -- while reusing the level / format /
// timestamp / WithAttrs machinery implemented here. The initial level is
// Info and the initial format is text.
func NewHandler(w io.Writer) *Handler {
	shared := &handlerRoot{}
	shared.level.Set(slog.LevelInfo)

	// Both inner handlers observe the same LevelVar, so SetLevel applies
	// regardless of which format is active.
	opts := slog.HandlerOptions{Level: &shared.level}
	h := &Handler{root: shared}
	h.text = slog.NewTextHandler(w, &opts)
	h.json = slog.NewJSONHandler(w, &opts)
	return h
}
// handlerRoot carries the reconfiguration state shared by every logger
// derived from a NewHandler call. All fields are consulted on the log
// path and updated lock-free.
type handlerRoot struct {
	// level is the shared minimum level; slog.LevelVar is safe for
	// concurrent Set/Level calls.
	level slog.LevelVar
	// disableTimestamp makes Handle zero each record's time before
	// dispatching, which suppresses the time attribute in output.
	disableTimestamp atomic.Bool
	// jsonMode picks which of the pre-derived inner handlers Handler.Handle
	// dispatches to. Flipping it propagates instantly to every derived logger
	// without rebuilding or chain-replaying anything.
	jsonMode atomic.Bool
}
// Handler is the slog.Handler returned by NewHandler. It holds two
// pre-derived slog handlers -- one text, one json -- both built from the
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
// effect immediately across the whole process without having to rebuild
// any derived loggers.
type Handler struct {
	// root is shared by every Handler derived via WithAttrs/WithGroup.
	root *handlerRoot
	// text and json carry identical attr/group state in the two formats.
	text slog.Handler
	json slog.Handler
}
// Enabled reports whether a record at level l would be emitted under the
// handler's current minimum level.
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
	return l >= h.root.level.Level()
}
// Handle dispatches r to the text or json inner handler depending on the
// currently selected format, zeroing the record time first when timestamps
// are disabled (slog's builtin handlers skip the time attribute on a zero
// time).
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
	inner := h.text
	if h.root.jsonMode.Load() {
		inner = h.json
	}
	if h.root.disableTimestamp.Load() {
		r.Time = time.Time{}
	}
	return inner.Handle(ctx, r)
}
// WithAttrs returns a Handler whose text and json inner handlers both
// carry the extra attrs, sharing the same root so reconfiguration still
// reaches the derived logger.
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return h
	}
	derived := Handler{
		root: h.root,
		text: h.text.WithAttrs(attrs),
		json: h.json.WithAttrs(attrs),
	}
	return &derived
}
// WithGroup returns a Handler whose text and json inner handlers both open
// the named group, sharing the same root so reconfiguration still reaches
// the derived logger.
func (h *Handler) WithGroup(name string) slog.Handler {
	if name == "" {
		return h
	}
	derived := Handler{
		root: h.root,
		text: h.text.WithGroup(name),
		json: h.json.WithGroup(name),
	}
	return &derived
}
// SetLevel updates the effective log level. The change propagates to every
// derived logger because they all consult the shared LevelVar.
func (h *Handler) SetLevel(level slog.Level) {
	h.root.level.Set(level)
}

// GetLevel reports the currently effective log level.
func (h *Handler) GetLevel() slog.Level {
	return h.root.level.Level()
}
// SetFormat flips the output format atomically. Valid formats are "text"
// and "json"; anything else is rejected with an error and leaves the
// current format unchanged. Every derived logger sees the new format on
// its next Handle call; no rebuild or registration is required.
func (h *Handler) SetFormat(format string) error {
	if format == "text" {
		h.root.jsonMode.Store(false)
		return nil
	}
	if format == "json" {
		h.root.jsonMode.Store(true)
		return nil
	}
	return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
}
// GetFormat reports the currently selected format name, "text" or "json".
func (h *Handler) GetFormat() string {
	if !h.root.jsonMode.Load() {
		return "text"
	}
	return "json"
}
// SetDisableTimestamp toggles whether Handle zeroes each record's time
// before dispatching (slog's builtin text/json handlers omit the time
// attribute when the record time is zero).
func (h *Handler) SetDisableTimestamp(v bool) {
	h.root.disableTimestamp.Store(v)
}
// ApplyConfig reads logging.level, logging.format, and (optionally)
// logging.disable_timestamp from c and applies them to l. Capabilities are
// discovered by structurally type-asserting l.Handler(), so foreign
// handlers silently opt out of whichever setters they do not implement.
// An invalid level or format aborts with an error; earlier settings that
// already applied are not rolled back.
//
// nebula.Main does NOT call this function on your behalf; callers that want
// config-driven log level / format / timestamp updates invoke it at
// startup and register it as a reload callback themselves. This keeps the
// library from mutating an embedder's logger without their say-so.
func ApplyConfig(l *slog.Logger, c Config) error {
	handler := l.Handler()

	level, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return err
	}
	if s, ok := handler.(interface{ SetLevel(slog.Level) }); ok {
		s.SetLevel(level)
	}

	format := strings.ToLower(c.GetString("logging.format", "text"))
	if s, ok := handler.(interface{ SetFormat(string) error }); ok {
		if ferr := s.SetFormat(format); ferr != nil {
			return ferr
		}
	}

	if s, ok := handler.(interface{ SetDisableTimestamp(bool) }); ok {
		s.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
	}
	return nil
}
// ParseLevel converts a config-string level name ("trace", "debug", "info",
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and
// "panic" are accepted for backwards compatibility with pre-slog configs
// and both map to slog.LevelError. Unknown names return an error.
func ParseLevel(s string) (slog.Level, error) {
	known := map[string]slog.Level{
		"trace":   LevelTrace,
		"debug":   slog.LevelDebug,
		"info":    slog.LevelInfo,
		"warn":    slog.LevelWarn,
		"warning": slog.LevelWarn,
		"error":   slog.LevelError,
		"fatal":   slog.LevelError,
		"panic":   slog.LevelError,
	}
	lvl, ok := known[s]
	if !ok {
		return 0, fmt.Errorf("not a valid logging level: %q", s)
	}
	return lvl, nil
}
// LevelName returns a human-readable name for a slog.Level matching the
// strings accepted by ParseLevel. Levels between two named values round up
// to the coarser (higher) bucket; anything above Warn reports "error".
func LevelName(l slog.Level) string {
	if l <= LevelTrace {
		return "trace"
	}
	if l <= slog.LevelDebug {
		return "debug"
	}
	if l <= slog.LevelInfo {
		return "info"
	}
	if l <= slog.LevelWarn {
		return "warn"
	}
	return "error"
}

View File

@@ -0,0 +1,90 @@
package logging
import (
"context"
"io"
"log/slog"
"testing"
)
// BenchmarkLogger_* compare the handler returned by NewLogger against a
// stock slog text handler. The key thing we care about is the per-log
// cost on a logger that has been derived via .With(), because that is the
// shape subsystems store on their structs (HostInfo.logger(),
// lh.l.With("subsystem", ...), etc.) and call from hot paths.
// BenchmarkLogger_Stock_RootInfo measures a root-level Info call on a
// stock slog logger backed by slog.DiscardHandler; the baseline the
// nebula handler is compared against.
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
	logger := slog.New(slog.DiscardHandler)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Nebula_RootInfo measures a root-level Info call on the
// nebula handler writing to io.Discard, for comparison with the stock
// baseline above.
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
	logger := NewLogger(io.Discard)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Stock_DerivedInfo measures Info on a stock logger that
// has been derived via .With() -- the shape subsystems store on their
// structs and call from hot paths.
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
	logger := slog.New(slog.DiscardHandler).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// BenchmarkLogger_Nebula_DerivedInfo measures Info on a nebula logger
// derived via .With(), the hot-path shape the per-log cost matters most
// for.
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
	logger := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		logger.Info("hello", "i", n)
	}
}
// Gated-off-path benchmarks: mimic the typical hot-path shape
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below
// the active level. This is the dominant pattern in inside.go/outside.go and
// what we pay on every packet.
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
l := slog.New(slog.DiscardHandler).With(
"subsystem", "bench",
"localIndex", 1234,
)
ctx := context.Background()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if l.Enabled(ctx, slog.LevelDebug) {
l.Debug("hello", "i", i)
}
}
}
// BenchmarkLogger_Nebula_DerivedEnabledGateMiss measures the gated-off Debug
// pattern on a NewLogger-produced derived logger, mirroring the stock case.
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
	derived := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	ctx := context.Background()
	b.ReportAllocs()
	b.ResetTimer()
	for n := range b.N {
		if derived.Enabled(ctx, slog.LevelDebug) {
			derived.Debug("hello", "i", n)
		}
	}
}

39
main.go
View File

@@ -3,13 +3,13 @@ package nebula
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"runtime/debug"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/sshd"
@@ -20,7 +20,7 @@ import (
type m = map[string]any
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() {
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
buildVersion = moduleVersion()
}
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
// Print the config if in test, the exit comes later
if configTest {
b, err := yaml.Marshal(c.Settings)
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
// Print the final config
l.Println(string(b))
l.Info(string(b))
}
err := configLogger(l, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
}
c.RegisterReloadCallback(func(c *config.C) {
err := configLogger(l, c)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
}
})
pki, err := NewPKIFromConfig(l, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
}
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
sshStart = nil
}
}
@@ -99,7 +82,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
routines = 1
}
if routines > 1 {
l.WithField("routines", routines).Info("Using multiple routines")
l.Info("Using multiple routines", "routines", routines)
}
} else {
// deprecated and undocumented
@@ -107,7 +90,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
udpQueues := c.GetInt("listen.routines", 1)
routines = max(tunQueues, udpQueues)
if routines != 1 {
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
}
}
@@ -120,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
conntrackCacheTimeout = 1 * time.Second
}
if conntrackCacheTimeout > 0 {
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
}
var tun overlay.Device
@@ -166,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
@@ -217,7 +200,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
if err != nil {
l.WithError(err).Warn("Failed to start DNS responder")
l.Warn("Failed to start DNS responder", "error", err)
}
ifConfig := &InterfaceConfig{

View File

@@ -1,15 +1,16 @@
package nebula
import (
"context"
"encoding/binary"
"errors"
"log/slog"
"net/netip"
"time"
"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/net/ipv4"
@@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
f.l.Info("Error while parsing inbound packet",
"from", via,
"error", err,
"packet", packet,
)
}
return
}
@@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
//l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Refusing to process double encrypted packet", "from", via)
}
return
}
@@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
"vpnAddrs", hostinfo.vpnAddrs,
"remoteIndex", h.RemoteIndex,
)
return
}
@@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
"relayTo", relay.PeerAddr,
"error", err,
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
)
return
}
@@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
hostinfo.logger(f.l).Info("Unexpected target relay state",
"relayTo", relay.PeerAddr,
"relayFrom", hostinfo.vpnAddrs[0],
"targetRelayState", targetRelay.State,
)
return
}
}
@@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
@@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt test packet")
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
@@ -192,14 +212,15 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
}
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt CloseTunnel packet")
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
hostinfo.logger(f.l).WithField("from", via).
Info("Close tunnel received, tearing down.")
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
f.closeTunnel(hostinfo)
return
@@ -211,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
"error", err,
"from", via,
"packet", packet,
)
return
}
@@ -221,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
}
return
}
@@ -247,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
}
return
}
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
"suppressSeconds", RoamingSuppressSeconds,
"udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
}
return
}
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
"udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(via.UdpAddr)
@@ -491,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
}
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", h).
Debugln("dropping out of window packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
}
return nil, errors.New("out of window packet")
}
@@ -504,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
return false
}
err = newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err,
"packet", out,
)
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
}
return false
}
@@ -526,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping inbound packet",
"fwPacket", fwPacket,
"reason", dropReason,
)
}
return false
}
@@ -537,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
f.l.Error("Failed to write to tun", "error", err)
}
return true
}
@@ -553,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
_ = f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
WithField("udpAddr", endpoint).
Debug("Recv error sent")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Recv error sent",
"index", index,
"udpAddr", endpoint,
)
}
}
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received, ignoring")
f.l.Debug("Recv error received, ignoring",
"index", h.RemoteIndex,
"udpAddr", addr,
)
return
}
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received")
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Recv error received",
"index", h.RemoteIndex,
"udpAddr", addr,
)
}
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if hostinfo == nil {
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
return
}
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
f.l.Info("Someone spoofing recv_errors?",
"addr", addr,
"hostinfoRemote", hostinfo.remote,
)
return
}

View File

@@ -1,4 +1,6 @@
package test
// Package overlaytest provides fakes of overlay.Device for tests that do
// not want to touch a real tun device or route table.
package overlaytest
import (
"errors"
@@ -8,6 +10,9 @@ import (
"github.com/slackhq/nebula/routing"
)
// NoopTun is an overlay.Device that silently discards every read and write.
// Useful in tests that need to construct a nebula Interface but do not
// exercise the datapath.
type NoopTun struct{}
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {

View File

@@ -2,6 +2,7 @@ package overlay
import (
"fmt"
"log/slog"
"math"
"net"
"net/netip"
@@ -9,7 +10,6 @@ import (
"strconv"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
@@ -48,11 +48,14 @@ func (r Route) String() string {
return s
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
routeTree := new(bart.Table[routing.Gateways])
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
l.Warn("route MTU is not supported on this platform",
"goos", runtime.GOOS,
"route", r,
)
}
gateways := r.Via

View File

@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true)
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("1.0.0.2")
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 3)
routeTree, err := makeRouteTree(l, routes, true)
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("192.168.86.1")

View File

@@ -2,10 +2,10 @@ package overlay
import (
"fmt"
"log/slog"
"net"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
)
@@ -22,9 +22,9 @@ func (e *NameError) Error() string {
}
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
}
}

View File

@@ -6,12 +6,12 @@ package overlay
import (
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -23,10 +23,10 @@ type tun struct {
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync/atomic"
@@ -14,7 +15,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -30,7 +30,7 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
l *slog.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error {
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr)
} else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
@@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error {
}
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -1,13 +1,14 @@
package overlay
import (
"context"
"fmt"
"io"
"log/slog"
"net/netip"
"strings"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/routing"
)
@@ -19,10 +20,10 @@ type disabledTun struct {
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
rx metrics.Counter
l *logrus.Logger
l *slog.Logger
}
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
tun := &disabledTun{
vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen),
@@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
}
t.tx.Inc(1)
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Write payload", "raw", prettyPacket(r))
}
return copy(b, r), nil
@@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
select {
case t.read <- out:
default:
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
t.l.Debug("tun_disabled: dropped ICMP Echo Reply response")
}
return true
@@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
// Check for ICMP Echo Request before spending time doing the full parsing
if t.handleICMPEchoRequest(b) {
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b))
}
} else if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
} else if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b))
}
return len(b), nil
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/netip"
"os"
"sync/atomic"
@@ -17,8 +18,9 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
@@ -93,7 +95,7 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
l *slog.Logger
fd int
shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls
@@ -243,7 +245,7 @@ func (t *tun) Close() error {
if t.fd >= 0 {
if err := unix.Close(t.fd); err != nil {
t.l.WithError(err).Error("Error closing device")
t.l.Error("Error closing device", "error", err)
}
t.fd = -1
}
@@ -264,7 +266,7 @@ func (t *tun) Close() error {
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
}
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
t.l.Error("Error destroying tunnel", "error", err)
}
}()
@@ -277,11 +279,11 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var err error
@@ -584,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -599,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync"
@@ -14,7 +15,6 @@ import (
"syscall"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -25,14 +25,14 @@ type tun struct {
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
vpnNetworks: vpnNetworks,

View File

@@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"os"
@@ -17,7 +18,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -213,7 +213,7 @@ type tun struct {
routesFromSystem map[netip.Prefix]routing.Gateways
routesFromSystemLock sync.Mutex
l *logrus.Logger
l *slog.Logger
}
func (t *tun) Networks() []netip.Prefix {
@@ -238,7 +238,7 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
if err != nil {
return nil, err
@@ -249,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -299,7 +299,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
}
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
tfd, err := newTunFd(fd)
if err != nil {
_ = unix.Close(fd)
@@ -378,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error {
if !initial {
if oldMaxMTU != newMaxMTU {
t.setMTU()
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
t.l.Warn(err.Error())
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU)
}
}
}
@@ -492,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error {
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
t.l.Error("failed to remove address from tun address list", "error", err)
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
t.l.Info("removed address not listed in cert(s)", "removed", al[i].String())
}
}
@@ -538,12 +538,12 @@ func (t *tun) Activate() error {
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length")
t.l.Error("Failed to set tun tx queue length", "error", err)
}
const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation")
t.l.Warn("Failed to disable link local address generation", "error", err)
}
if err = t.addIPs(link); err != nil {
@@ -582,7 +582,7 @@ func (t *tun) setMTU() {
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu")
t.l.Error("Failed to set tun mtu", "error", err)
}
}
@@ -605,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
}
err := netlink.RouteReplace(&nr)
if err != nil {
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
for i := 0; i < 2; i++ {
time.Sleep(100 * time.Millisecond)
@@ -613,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
if err == nil {
break
} else {
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
t.l.Warn("Failed to set default route MTU, retrying",
"error", err,
"cidr", cidr,
"mtu", t.DefaultMTU,
)
}
}
if err != nil {
@@ -658,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -690,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) {
err := netlink.RouteDel(&nr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
}
@@ -721,11 +725,11 @@ func (t *tun) watchRoutes() {
netlinkOptions := netlink.RouteSubscribeOptions{
ReceiveBufferSize: t.useSystemRoutesBufferSize,
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) },
}
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
t.l.Error("failed to subscribe to system route changes", "error", err)
return
}
@@ -767,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device)
return gateways
}
@@ -779,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
}
}
@@ -795,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
}
}
}
@@ -830,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
t.l.Debug("Ignoring route update, no gateways", "route", r)
return
}
if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
t.l.Debug("Ignoring route update, no destination address", "route", r)
return
}
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
t.l.Debug("Ignoring route update, invalid destination address", "route", r)
return
}
@@ -852,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
t.routesFromSystemLock.Lock()
if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
t.l.Info("Adding route", "destination", dst, "via", gateways)
t.routesFromSystem[dst] = gateways
newTree.Insert(dst, gateways)
} else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
t.l.Info("Removing route", "destination", dst, "via", gateways)
delete(t.routesFromSystem, dst)
newTree.Delete(dst)
}
@@ -888,18 +892,18 @@ func (t *tun) Close() error {
}
err := t.readers[i].Close()
if err != nil {
t.l.WithField("reader", i).WithError(err).Error("error closing tun reader")
t.l.Error("error closing tun reader", "reader", i, "error", err)
} else {
t.l.WithField("reader", i).Info("closed tun reader")
t.l.Info("closed tun reader", "reader", i)
}
}
//this is t.readers[0] too
err := t.tunFile.Close()
if err != nil {
t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader")
t.l.Error("error closing tun reader", "reader", 0, "error", err)
} else {
t.l.WithField("reader", 0).Info("closed tun reader")
t.l.Info("closed tun reader", "reader", 0)
}
return err
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"regexp"
@@ -15,7 +16,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -63,18 +63,18 @@ type tun struct {
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
f *os.File
fd int
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
@@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
l.Warn("Failed to set the tun device as nonblocking", "error", err)
}
t := &tun{
@@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"regexp"
@@ -15,7 +16,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -54,7 +54,7 @@ type tun struct {
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
f *os.File
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -63,11 +63,11 @@ type tun struct {
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
@@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
l.Warn("Failed to set the tun device as nonblocking", "error", err)
}
t := &tun{
@@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
}
@@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -4,14 +4,15 @@
package overlay
import (
"context"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
@@ -21,14 +22,14 @@ type TestTun struct {
vpnNetworks []netip.Prefix
Routes []Route
routeTree *bart.Table[routing.Gateways]
l *logrus.Logger
l *slog.Logger
closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil {
return nil, err
@@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}, nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
@@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) {
return
}
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
}
t.rxPackets <- packet
}

View File

@@ -7,6 +7,7 @@ import (
"crypto"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"path/filepath"
@@ -16,7 +17,6 @@ import (
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
@@ -33,16 +33,16 @@ type winTun struct {
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
l *slog.Logger
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists()
if err != nil {
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
@@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
l.Debug("Failed to create wintun device, retrying", "error", err)
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, &NameError{
@@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
t.l.Info("Added route", "route", r)
}
if !foundDefault4 {
@@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error {
// See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
t.l.Error("Failed to remove route", "error", err, "route", r)
} else {
t.l.WithField("route", r).Info("Removed route")
t.l.Info("Removed route", "route", r)
}
}
return nil

View File

@@ -2,14 +2,14 @@ package overlay
import (
"io"
"log/slog"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(vpnNetworks)
}

18
pki.go
View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"os"
@@ -15,7 +16,6 @@ import (
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
@@ -24,7 +24,7 @@ import (
type PKI struct {
cs atomic.Pointer[CertState]
caPool atomic.Pointer[cert.CAPool]
l *logrus.Logger
l *slog.Logger
}
type CertState struct {
@@ -46,7 +46,7 @@ type CertState struct {
myVpnBroadcastAddrsTable *bart.Lite
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) {
pki := &PKI{l: l}
err := pki.reload(c, true)
if err != nil {
@@ -182,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState)
if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
p.l.Debug("Client nebula certificate(s)", "cert", newState)
} else {
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
p.l.Info("Client certificate(s) refreshed from disk", "cert", newState)
}
return nil
}
@@ -196,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
}
p.caPool.Store(caPool)
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints())
return nil
}
@@ -487,7 +487,7 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
return c, b, nil
}
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) {
caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided")
@@ -512,7 +512,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
for _, crt := range caPool.CAs {
if crt.Certificate.Expired(time.Now()) {
expired++
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
l.Warn("expired certificate present in CA pool", "cert", crt)
}
}
@@ -530,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
caPool.BlocklistFingerprint(fp)
}
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
l.Info("Blocklisted certificates", "fingerprintCount", len(bl))
}
return caPool, nil

View File

@@ -41,7 +41,7 @@ func BenchmarkReloadConfigWithCAs(b *testing.B) {
c := config.NewC(l)
require.NoError(b, c.Load(dir))
_, err := NewPKIFromConfig(l, c)
_, err := NewPKIFromConfig(test.NewLogger(), c)
require.NoError(b, err)
b.ReportAllocs()

View File

@@ -1,10 +1,10 @@
package nebula
import (
"log/slog"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
@@ -14,10 +14,10 @@ type Punchy struct {
delay atomic.Int64
respondDelay atomic.Int64
punchEverything atomic.Bool
l *logrus.Logger
l *slog.Logger
}
func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
p := &Punchy{l: l}
p.reload(c, true)
@@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
p.respond.Store(yes)
if !initial {
p.l.Infof("punchy.respond changed to %v", p.GetRespond())
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
}
}
@@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) {
if initial || c.HasChanged("punchy.delay") {
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
if !initial {
p.l.Infof("punchy.delay changed to %s", p.GetDelay())
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
}
}
if initial || c.HasChanged("punchy.target_all_remotes") {
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
if !initial {
p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
}
}
if initial || c.HasChanged("punchy.respond_delay") {
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
if !initial {
p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay())
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
}
}
}

View File

@@ -1,6 +1,8 @@
package nebula
import (
"context"
"log/slog"
"testing"
"time"
@@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
c := config.NewC(l)
// Test defaults
p := NewPunchyFromConfig(l, c)
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.False(t, p.GetPunch())
assert.False(t, p.GetRespond())
assert.Equal(t, time.Second, p.GetDelay())
@@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) {
// punchy deprecation
c.Settings["punchy"] = true
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
// punchy.punch
c.Settings["punchy"] = map[string]any{"punch": true}
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch())
// punch_back deprecation
c.Settings["punch_back"] = true
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
// punchy.respond
c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punch_back"] = false
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond())
// punchy.delay
c.Settings["punchy"] = map[string]any{"delay": "1m"}
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetDelay())
// punchy.respond_delay
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
p = NewPunchyFromConfig(l, c)
p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetRespondDelay())
}
@@ -62,7 +64,7 @@ punchy:
delay: 1m
respond: false
`))
p := NewPunchyFromConfig(l, c)
p := NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, delay, p.GetDelay())
assert.False(t, p.GetRespond())
@@ -76,3 +78,158 @@ punchy:
assert.Equal(t, newDelay, p.GetDelay())
assert.True(t, p.GetRespond())
}
// The tests below pin the shape of each log line Punchy produces so changes
// cannot silently break whatever operators are grepping for. The assertions
// are on the structured message + attrs (e.g. "punchy.respond changed" with
// a respond=true field) rather than a formatted string.
//
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
// not supported" warning whenever any key under punchy changes, because of
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
// punchy form. The tests filter by message rather than asserting total
// entry counts so that warning is tolerated without being locked into
// the format.
type capturedEntry struct {
Level slog.Level
Msg string
Attrs map[string]any
}
// capturingHandler is a slog.Handler that records each Record it receives so
// tests can assert on the level, message, and attribute map of individual log
// lines without coupling to any specific text format.
type capturingHandler struct {
entries []capturedEntry
}
func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error {
e := capturedEntry{
Level: r.Level,
Msg: r.Message,
Attrs: make(map[string]any),
}
r.Attrs(func(a slog.Attr) bool {
e.Attrs[a.Key] = a.Value.Resolve().Any()
return true
})
h.entries = append(h.entries, e)
return nil
}
func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h }
// newCapturingPunchyLogger builds a logger whose output is recorded by the
// returned capturingHandler instead of being written anywhere.
func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) {
	t.Helper()
	handler := &capturingHandler{}
	logger := slog.New(handler)
	return logger, handler
}
// findEntry returns the first captured log line whose message equals msg,
// failing the test immediately if none of the recorded entries match.
func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry {
	t.Helper()
	for i := range entries {
		if entries[i].Msg == msg {
			return entries[i]
		}
	}
	t.Fatalf("no entry with message %q among %d entries", msg, len(entries))
	return capturedEntry{}
}
// TestPunchy_LogFormat_InitialEnabled pins the line logged when punchy is
// enabled at startup: an Info entry with message "punchy enabled" and no attrs.
func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	conf := config.NewC(test.NewLogger())
	require.NoError(t, conf.LoadString(`punchy: {punch: true}`))

	NewPunchyFromConfig(logger, conf)

	e := findEntry(t, captured.entries, "punchy enabled")
	assert.Equal(t, slog.LevelInfo, e.Level)
	assert.Empty(t, e.Attrs)
}
// TestPunchy_LogFormat_InitialDisabled pins the line logged when punchy is
// disabled at startup: an Info entry with message "punchy disabled" and no attrs.
func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	conf := config.NewC(test.NewLogger())
	require.NoError(t, conf.LoadString(`punchy: {punch: false}`))

	NewPunchyFromConfig(logger, conf)

	e := findEntry(t, captured.entries, "punchy disabled")
	assert.Equal(t, slog.LevelInfo, e.Level)
	assert.Empty(t, e.Attrs)
}
// TestPunchy_LogFormat_ReloadPunchUnsupported pins the warning emitted when a
// config reload tries to flip punchy.punch, which reload does not support.
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	conf := config.NewC(test.NewLogger())
	require.NoError(t, conf.LoadString(`punchy: {punch: false}`))
	NewPunchyFromConfig(logger, conf)

	// Drop the startup entries so only reload output is inspected.
	captured.entries = nil
	require.NoError(t, conf.ReloadConfigString(`punchy: {punch: true}`))

	e := findEntry(t, captured.entries, "Changing punchy.punch with reload is not supported, ignoring.")
	assert.Equal(t, slog.LevelWarn, e.Level)
	assert.Empty(t, e.Attrs)
}
// TestPunchy_LogFormat_ReloadRespond pins the Info line logged when a reload
// changes punchy.respond, including its respond=<bool> attribute.
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	conf := config.NewC(test.NewLogger())
	require.NoError(t, conf.LoadString(`punchy: {respond: false}`))
	NewPunchyFromConfig(logger, conf)

	// Drop the startup entries so only reload output is inspected.
	captured.entries = nil
	require.NoError(t, conf.ReloadConfigString(`punchy: {respond: true}`))

	e := findEntry(t, captured.entries, "punchy.respond changed")
	assert.Equal(t, slog.LevelInfo, e.Level)
	assert.Equal(t, map[string]any{"respond": true}, e.Attrs)
}
// TestPunchy_LogFormat_ReloadDelay pins the Info line logged when a reload
// changes punchy.delay, including its delay=<duration> attribute.
func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	conf := config.NewC(test.NewLogger())
	require.NoError(t, conf.LoadString(`punchy: {delay: 1s}`))
	NewPunchyFromConfig(logger, conf)

	// Drop the startup entries so only reload output is inspected.
	captured.entries = nil
	require.NoError(t, conf.ReloadConfigString(`punchy: {delay: 10s}`))

	e := findEntry(t, captured.entries, "punchy.delay changed")
	assert.Equal(t, slog.LevelInfo, e.Level)
	assert.Equal(t, map[string]any{"delay": 10 * time.Second}, e.Attrs)
}
// TestPunchy_LogFormat_ReloadTargetAllRemotes pins the Info line logged when a
// reload changes punchy.target_all_remotes, including its boolean attribute.
func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	conf := config.NewC(test.NewLogger())
	require.NoError(t, conf.LoadString(`punchy: {target_all_remotes: false}`))
	NewPunchyFromConfig(logger, conf)

	// Drop the startup entries so only reload output is inspected.
	captured.entries = nil
	require.NoError(t, conf.ReloadConfigString(`punchy: {target_all_remotes: true}`))

	e := findEntry(t, captured.entries, "punchy.target_all_remotes changed")
	assert.Equal(t, slog.LevelInfo, e.Level)
	assert.Equal(t, map[string]any{"target_all_remotes": true}, e.Attrs)
}
// TestPunchy_LogFormat_ReloadRespondDelay pins the Info line logged when a
// reload changes punchy.respond_delay, including its duration attribute.
func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	conf := config.NewC(test.NewLogger())
	require.NoError(t, conf.LoadString(`punchy: {respond_delay: 5s}`))
	NewPunchyFromConfig(logger, conf)

	// Drop the startup entries so only reload output is inspected.
	captured.entries = nil
	require.NoError(t, conf.ReloadConfigString(`punchy: {respond_delay: 15s}`))

	e := findEntry(t, captured.entries, "punchy.respond_delay changed")
	assert.Equal(t, slog.LevelInfo, e.Level)
	assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, e.Attrs)
}

View File

@@ -5,22 +5,22 @@ import (
"encoding/binary"
"errors"
"fmt"
"log/slog"
"net/netip"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
)
type relayManager struct {
l *logrus.Logger
l *slog.Logger
hostmap *HostMap
amRelay atomic.Bool
}
func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager {
func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
rm := &relayManager{
l: l,
hostmap: hostmap,
@@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c
c.RegisterReloadCallback(func(c *config.C) {
err := rm.reload(c, false)
if err != nil {
l.WithError(err).Error("Failed to reload relay_manager")
rm.l.Error("Failed to reload relay_manager", "error", err)
}
})
return rm
@@ -52,7 +52,7 @@ func (rm *relayManager) setAmRelay(v bool) {
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
hm.Lock()
defer hm.Unlock()
for range 32 {
@@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
if !ok {
fields := logrus.Fields{
"relay": relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex": m.InitiatorRelayIndex,
}
var relayFrom, relayTo any
if m.RelayFromAddr == nil {
fields["relayFrom"] = m.OldRelayFromAddr
relayFrom = m.OldRelayFromAddr
} else {
fields["relayFrom"] = m.RelayFromAddr
relayFrom = m.RelayFromAddr
}
if m.RelayToAddr == nil {
fields["relayTo"] = m.OldRelayToAddr
relayTo = m.OldRelayToAddr
} else {
fields["relayTo"] = m.RelayToAddr
relayTo = m.RelayToAddr
}
rm.l.WithFields(fields).Info("relayManager failed to update relay")
rm.l.Info("relayManager failed to update relay",
"relay", relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex", m.InitiatorRelayIndex,
"relayFrom", relayFrom,
"relayTo", relayTo,
)
return nil, fmt.Errorf("unknown relay")
}
@@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
msg := &NebulaControl{}
err := msg.Unmarshal(d)
if err != nil {
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
h.logger(f.l).Error("Failed to unmarshal control message", "error", err)
return
}
@@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
}
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{
"relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
"relayTo": protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}).
Info("handleCreateRelayResponse")
rm.l.Info("handleCreateRelayResponse",
"relayFrom", protoAddrToNetAddr(m.RelayFromAddr),
"relayTo", protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex", m.InitiatorRelayIndex,
"responderRelayIndex", m.ResponderRelayIndex,
"vpnAddrs", h.vpnAddrs,
)
target := m.RelayToAddr
targetAddr := protoAddrToNetAddr(target)
relay, err := rm.EstablishRelay(h, m)
if err != nil {
rm.l.WithError(err).Error("Failed to update relay for relayTo")
rm.l.Error("Failed to update relay for relayTo", "error", err)
return
}
// Do I need to complete the relays now?
@@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr)
return
}
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0])
return
}
switch peerRelay.State {
@@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
if v == cert.Version1 {
peer := peerHostInfo.vpnAddrs[0]
if !peer.Is4() {
rm.l.WithField("relayFrom", peer).
WithField("relayTo", target).
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
WithField("responderRelayIndex", resp.ResponderRelayIndex).
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address")
rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address",
"relayFrom", peer,
"relayTo", target,
"initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs", peerHostInfo.vpnAddrs,
)
return
}
@@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
msg, err := resp.Marshal()
if err != nil {
rm.l.WithError(err).
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": resp.RelayFromAddr,
"relayTo": resp.RelayToAddr,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnAddrs": peerHostInfo.vpnAddrs}).
Info("send CreateRelayResponse")
rm.l.Info("send CreateRelayResponse",
"relayFrom", resp.RelayFromAddr,
"relayTo", resp.RelayToAddr,
"initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs", peerHostInfo.vpnAddrs,
)
}
}
}
@@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
from := protoAddrToNetAddr(m.RelayFromAddr)
target := protoAddrToNetAddr(m.RelayToAddr)
logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from,
"relayTo": target,
"initiatorRelayIndex": m.InitiatorRelayIndex,
"vpnAddrs": h.vpnAddrs})
logMsg := rm.l.With(
"relayFrom", from,
"relayTo", target,
"initiatorRelayIndex", m.InitiatorRelayIndex,
"vpnAddrs", h.vpnAddrs,
)
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
if f.myVpnAddrsTable.Contains(from) {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
logMsg.Error("Discarding relay request from myself", "myIP", from)
return
}
@@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
"existingRemoteIndex", existingRelay.RemoteIndex)
return
}
case Disestablished:
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
"existingRemoteIndex", existingRelay.RemoteIndex)
return
}
// Mark the relay as 'Established' because it's safe to use again
h.relayState.UpdateRelayForByIpState(from, Established)
case PeerRequested:
// I should never be in this state, because I am terminal, not forwarding.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex,
"state": existingRelay.State}).Error("Unexpected Relay State found")
logMsg.Error("Unexpected Relay State found",
"existingRemoteIndex", existingRelay.RemoteIndex,
"state", existingRelay.State)
}
} else {
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
if err != nil {
logMsg.WithError(err).Error("Failed to add relay")
logMsg.Error("Failed to add relay", "error", err)
return
}
}
relay, ok := h.relayState.QueryRelayForByIp(from)
if !ok {
logMsg.WithField("from", from).Error("Relay State not found")
logMsg.Error("Relay State not found", "from", from)
return
}
@@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
msg, err := resp.Marshal()
if err != nil {
logMsg.
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": from,
"relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}).
Info("send CreateRelayResponse")
rm.l.Info("send CreateRelayResponse",
"relayFrom", from,
"relayTo", target,
"initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs", h.vpnAddrs,
)
}
return
} else {
@@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() {
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
WithField("relayTo", target).
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
WithField("responderRelayIndex", req.ResponderRelayIndex).
WithField("vpnAddr", target).
Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address")
rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address",
"relayFrom", h.vpnAddrs[0],
"relayTo", target,
"initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex", req.ResponderRelayIndex,
"vpnAddr", target,
)
return
}
@@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
msg, err := req.Marshal()
if err != nil {
logMsg.
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err)
} else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": h.vpnAddrs[0],
"relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnAddr": target}).
Info("send CreateRelayRequest")
rm.l.Info("send CreateRelayRequest",
"relayFrom", h.vpnAddrs[0],
"relayTo", target,
"initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex", req.ResponderRelayIndex,
"vpnAddr", target,
)
}
// Also track the half-created Relay state just received
@@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if !ok {
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
if err != nil {
logMsg.
WithError(err).Error("relayManager Failed to allocate a local index for relay")
logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err)
return
}
}

View File

@@ -2,6 +2,7 @@ package nebula
import (
"context"
"log/slog"
"net"
"net/netip"
"slices"
@@ -10,8 +11,6 @@ import (
"sync"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
@@ -66,11 +65,11 @@ type hostnamesResults struct {
network string
lookupTimeout time.Duration
cancelFn func()
l *logrus.Logger
l *slog.Logger
ips atomic.Pointer[map[netip.AddrPort]struct{}]
}
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
r := &hostnamesResults{
hostnames: make([]hostnamePort, len(hostPorts)),
network: network,
@@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
timeoutCancel()
if err != nil {
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
l.Error("DNS resolution failed for static_map host",
"hostname", hostPort.name,
"network", r.network,
"error", err,
)
continue
}
for _, a := range addrs {
@@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
}
}
if different {
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
l.Info("DNS results changed for host list",
"origSet", origSet,
"newSet", netipAddrs,
)
r.ips.Store(&netipAddrs)
onUpdate()
}

View File

@@ -10,11 +10,11 @@ import (
"time"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup"
@@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
panic(err)
}
logger := logrus.New()
logger.Out = os.Stdout
logger := logging.NewLogger(os.Stdout)
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {

83
ssh.go
View File

@@ -6,21 +6,21 @@ import (
"errors"
"flag"
"fmt"
"log/slog"
"maps"
"net"
"net/netip"
"os"
"path/filepath"
"reflect"
"runtime"
"runtime/pprof"
"sort"
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/sshd"
)
@@ -57,12 +57,12 @@ type sshDeviceInfoFlags struct {
Pretty bool
}
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) {
c.RegisterReloadCallback(func(c *config.C) {
if c.GetBool("sshd.enabled", false) {
sshRun, err := configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd")
l.Error("Failed to reconfigure the sshd", "error", err)
ssh.Stop()
}
if sshRun != nil {
@@ -78,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
// updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
listen := c.GetString("sshd.listen", "")
if listen == "" {
return nil, fmt.Errorf("sshd.listen must be provided")
@@ -120,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, caAuthorizedKey := range rawCAs {
err := ssh.AddTrustedCA(caAuthorizedKey)
if err != nil {
l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey)
continue
}
}
@@ -131,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, rk := range keys {
kDef, ok := rk.(map[string]any)
if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk)
continue
}
user, ok := kDef["user"].(string)
if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk)
continue
}
@@ -146,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
case string:
err := ssh.AddAuthorizedKey(user, v)
if err != nil {
l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
l.Warn("Failed to authorize key",
"error", err,
"sshKeyConfig", rk,
"sshKey", v,
)
continue
}
@@ -154,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, subK := range v {
sk, ok := subK.(string)
if !ok {
l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
l.Warn("Did not understand ssh key",
"sshKeyConfig", rk,
"sshKey", subK,
)
continue
}
err := ssh.AddAuthorizedKey(user, sk)
if err != nil {
l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key")
l.Warn("Failed to authorize key",
"error", err,
"sshKeyConfig", sk,
)
continue
}
}
default:
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk)
}
}
} else {
@@ -178,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
ssh.Stop()
runner = func() {
if err := ssh.Run(listen); err != nil {
l.WithField("err", err).Warn("Failed to run the SSH server")
l.Warn("Failed to run the SSH server", "error", err)
}
}
} else {
@@ -188,7 +198,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
return runner, nil
}
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
// sandboxDir defaults to a dir in temp. The intention is that end user will
// create this dir as needed. Overriding this config value to "" allows
// writing to anywhere in the system.
@@ -789,36 +799,45 @@ func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
}
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
ctrl, ok := l.Handler().(interface {
GetLevel() slog.Level
SetLevel(slog.Level)
})
if !ok {
return w.WriteLine("Log level is not reconfigurable on this logger")
}
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
}
level, err := logrus.ParseLevel(a[0])
level, err := logging.ParseLevel(strings.ToLower(a[0]))
if err != nil {
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels))
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a))
}
l.SetLevel(level)
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
ctrl.SetLevel(level)
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
}
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
ctrl, ok := l.Handler().(interface {
GetFormat() string
SetFormat(string) error
})
if !ok {
return w.WriteLine("Log format is not reconfigurable on this logger")
}
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
}
logFormat := strings.ToLower(a[0])
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{}
case "json":
l.Formatter = &logrus.JSONFormatter{}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil {
return err
}
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
}
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {

View File

@@ -5,16 +5,16 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
type SSHServer struct {
config *ssh.ServerConfig
l *logrus.Entry
l *slog.Logger
certChecker *ssh.CertChecker
@@ -33,7 +33,7 @@ type SSHServer struct {
}
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
ctx, cancel := context.WithCancel(context.Background())
s := &SSHServer{
@@ -121,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error {
}
s.trustedCAs = append(s.trustedCAs, pk)
s.l.WithField("sshKey", pubKey).Info("Trusted CA key")
s.l.Info("Trusted CA key", "sshKey", pubKey)
return nil
}
@@ -139,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
}
tk[string(pk.Marshal())] = true
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
s.l.Info("Authorized ssh key",
"sshKey", pubKey,
"sshUser", user,
)
return nil
}
@@ -156,7 +159,7 @@ func (s *SSHServer) Run(addr string) error {
return err
}
s.l.WithField("sshListener", addr).Info("SSH server is listening")
s.l.Info("SSH server is listening", "sshListener", addr)
// Run loops until there is an error
s.run()
@@ -172,7 +175,7 @@ func (s *SSHServer) run() {
c, err := s.listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
s.l.WithError(err).Warn("Error in listener, shutting down")
s.l.Warn("Error in listener, shutting down", "error", err)
}
return
}
@@ -193,23 +196,29 @@ func (s *SSHServer) run() {
}
if err != nil {
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
l := s.l.With(
"error", err,
"remoteAddress", c.RemoteAddr(),
)
if conn != nil {
l = l.WithField("sshUser", conn.User())
l = l.With("sshUser", conn.User())
conn.Close()
}
if fp != "" {
l = l.WithField("sshFingerprint", fp)
l = l.With("sshFingerprint", fp)
}
l.Warn("failed to handshake")
sessionCancel()
return
}
l := s.l.WithField("sshUser", conn.User())
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
l := s.l.With("sshUser", conn.User())
l.Info("ssh user logged in",
"remoteAddress", c.RemoteAddr(),
"sshFingerprint", fp,
)
NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session"))
NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session"))
go ssh.DiscardRequests(reqs)
@@ -221,7 +230,7 @@ func (s *SSHServer) Stop() {
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil {
if err := s.listener.Close(); err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
s.l.Warn("Failed to close the sshd listener", "error", err)
}
}
}

View File

@@ -2,25 +2,25 @@ package sshd
import (
"fmt"
"log/slog"
"sort"
"strings"
"github.com/anmitsu/go-shlex"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
type session struct {
l *logrus.Entry
l *slog.Logger
c *ssh.ServerConn
term *term.Terminal
commands *radix.Tree
cancel func()
}
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session {
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session {
s := &session{
commands: radix.NewFromMap(commands.ToMap()),
l: l,
@@ -45,14 +45,14 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
defer s.Close()
for newChannel := range chans {
if newChannel.ChannelType() != "session" {
s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType())
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
s.l.WithError(err).Warn("could not accept channel")
s.l.Warn("could not accept channel", "error", err)
continue
}
@@ -95,12 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
return
default:
s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
s.l.Debug("Rejected unknown request", "sshRequest", req.Type)
err = req.Reply(false, nil)
}
if err != nil {
s.l.WithError(err).Info("Error handling ssh session requests")
s.l.Info("Error handling ssh session requests", "error", err)
return
}
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"log"
"log/slog"
"net"
"net/http"
"runtime"
@@ -15,14 +16,13 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// startStats initializes stats from config. On success, if any further work
// is needed to serve stats, it returns a func to handle that work. If no
// work is needed, it'll return nil. On failure, it returns nil, error.
func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
func startStats(l *slog.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" {
return nil, nil
@@ -59,7 +59,7 @@ func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest b
return startFn, nil
}
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
func startGraphiteStats(l *slog.Logger, i time.Duration, c *config.C, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "")
if host == "" {
@@ -73,13 +73,17 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTe
}
if !configTest {
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
l.Info("Starting graphite",
"interval", i,
"prefix", prefix,
"addr", addr.String(),
)
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
}
return nil
}
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
func startPrometheusStats(l *slog.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "")
@@ -116,9 +120,15 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV
var startFn func()
if !configTest {
// promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger,
// so bridge our slog.Logger back to a *log.Logger that emits at Error.
errLog := slog.NewLogLogger(l.Handler(), slog.LevelError)
startFn = func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
l.Info("Prometheus stats listening",
"listen", listen,
"path", path,
)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog}))
log.Fatal(http.ListenAndServe(listen, nil))
}
}

View File

@@ -1,29 +1,73 @@
package test
import (
"context"
"io"
"log/slog"
"os"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/logging"
)
func NewLogger() *logrus.Logger {
l := logrus.New()
// NewLogger returns a *slog.Logger suitable for use in tests. Output goes to
// io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to
// stream output to stderr for local debugging.
func NewLogger() *slog.Logger {
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(io.Discard)
return l
return slog.New(slog.DiscardHandler)
}
level := slog.LevelInfo
switch v {
case "2":
l.SetLevel(logrus.DebugLevel)
level = slog.LevelDebug
case "3":
l.SetLevel(logrus.TraceLevel)
default:
l.SetLevel(logrus.InfoLevel)
level = logging.LevelTrace
}
return l
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
}
// NewLoggerWithOutput builds a *slog.Logger that writes text records to w.
// Record timestamps are zeroed via stripTimeHandler, so tests can compare
// output against exact expected strings without baking in the current time.
func NewLoggerWithOutput(w io.Writer) *slog.Logger {
	h := slog.NewTextHandler(w, nil)
	return slog.New(&stripTimeHandler{inner: h})
}
// NewLoggerWithOutputAndLevel behaves like NewLoggerWithOutput but lets the
// caller choose the minimum level, for tests that exercise Enabled-gated paths.
func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger {
	opts := &slog.HandlerOptions{Level: level}
	return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, opts)})
}
// NewJSONLoggerWithOutput emits JSON records to w at the given level with
// timestamps suppressed, for tests that pin the exact JSON shape.
func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger {
	opts := &slog.HandlerOptions{Level: level}
	return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, opts)})
}
// stripTimeHandler zeros each record's time before delegating so slog's
// built-in handlers skip emitting the time attribute. Used to avoid
// timestamp-dependent assertions in tests without resorting to ReplaceAttr.
type stripTimeHandler struct {
inner slog.Handler
}
func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool {
return h.inner.Enabled(ctx, l)
}
func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error {
r.Time = time.Time{}
return h.inner.Handle(ctx, r)
}
func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)}
}
func (h *stripTimeHandler) WithGroup(name string) slog.Handler {
return &stripTimeHandler{inner: h.inner.WithGroup(name)}
}

View File

@@ -9,11 +9,12 @@ import (
"net/netip"
"syscall"
"github.com/sirupsen/logrus"
"log/slog"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}

View File

@@ -12,11 +12,12 @@ import (
"net/netip"
"syscall"
"github.com/sirupsen/logrus"
"log/slog"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}

View File

@@ -8,12 +8,12 @@ import (
"encoding/binary"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"golang.org/x/sys/unix"
)
@@ -22,12 +22,12 @@ type StdConn struct {
*net.UDPConn
isV4 bool
sysFd uintptr
l *logrus.Logger
l *slog.Logger
}
var _ Conn = &StdConn{}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
@@ -176,7 +176,7 @@ func (u *StdConn) ListenOut(r EncReader) error {
return err
}
u.l.WithError(err).Error("unexpected udp socket receive error")
u.l.Error("unexpected udp socket receive error", "error", err)
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
@@ -196,7 +196,7 @@ func (u *StdConn) Rebind() error {
}
if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket")
u.l.Error("Failed to rebind udp socket", "error", err)
}
return nil

View File

@@ -12,22 +12,22 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"net/netip"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
type GenericConn struct {
*net.UDPConn
l *logrus.Logger
l *slog.Logger
}
var _ Conn = &GenericConn{}
func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
@@ -88,7 +88,7 @@ func (u *GenericConn) ListenOut(r EncReader) error {
// Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
lastRecvErr = time.Now()
u.l.WithError(err).Warn("unexpected udp socket receive error")
u.l.Warn("unexpected udp socket receive error", "error", err)
}
continue
}

View File

@@ -7,13 +7,13 @@ import (
"context"
"encoding/binary"
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"unsafe"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"golang.org/x/sys/unix"
)
@@ -22,7 +22,7 @@ type StdConn struct {
udpConn *net.UDPConn
rawConn syscall.RawConn
isV4 bool
l *logrus.Logger
l *slog.Logger
batch int
}
@@ -38,7 +38,7 @@ func setReusePort(network, address string, c syscall.RawConn) error {
return opErr
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
listen := netip.AddrPortFrom(ip, uint16(port))
lc := net.ListenConfig{}
if multi {
@@ -242,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
if err == nil {
s, err := u.GetRecvBuffer()
if err == nil {
u.l.WithField("size", s).Info("listen.read_buffer was set")
u.l.Info("listen.read_buffer was set", "size", s)
} else {
u.l.WithError(err).Warn("Failed to get listen.read_buffer")
u.l.Warn("Failed to get listen.read_buffer", "error", err)
}
} else {
u.l.WithError(err).Error("Failed to set listen.read_buffer")
u.l.Error("Failed to set listen.read_buffer", "error", err)
}
}
@@ -257,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
if err == nil {
s, err := u.GetSendBuffer()
if err == nil {
u.l.WithField("size", s).Info("listen.write_buffer was set")
u.l.Info("listen.write_buffer was set", "size", s)
} else {
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
u.l.Warn("Failed to get listen.write_buffer", "error", err)
}
} else {
u.l.WithError(err).Error("Failed to set listen.write_buffer")
u.l.Error("Failed to set listen.write_buffer", "error", err)
}
}
@@ -273,12 +273,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
if err == nil {
s, err := u.GetSoMark()
if err == nil {
u.l.WithField("mark", s).Info("listen.so_mark was set")
u.l.Info("listen.so_mark was set", "mark", s)
} else {
u.l.WithError(err).Warn("Failed to get listen.so_mark")
u.l.Warn("Failed to get listen.so_mark", "error", err)
}
} else {
u.l.WithError(err).Error("Failed to set listen.so_mark")
u.l.Error("Failed to set listen.so_mark", "error", err)
}
}
}

View File

@@ -11,11 +11,12 @@ import (
"net/netip"
"syscall"
"github.com/sirupsen/logrus"
"log/slog"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}

View File

@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"sync"
@@ -17,7 +18,6 @@ import (
"time"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio"
@@ -53,14 +53,14 @@ type ringBuffer struct {
type RIOConn struct {
isOpen atomic.Bool
l *logrus.Logger
l *slog.Logger
sock windows.Handle
rx, tx ringBuffer
rq winrio.Rq
results [packetsPerRing]winrio.Result
}
func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) {
if !winrio.Initialize() {
return nil, errors.New("could not initialize winrio")
}
@@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro
return u, nil
}
func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
var err error
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
@@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
if err != nil {
// This is a best-effort to prevent errors from being returned by the udp recv operation.
// Quietly log a failure and continue.
l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl")
l.Debug("failed to set UDP_CONNRESET ioctl", "error", err)
}
ret = 0
@@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
if err != nil {
// This is a best-effort to prevent errors from being returned by the udp recv operation.
// Quietly log a failure and continue.
l.WithError(err).Debug("failed to set UDP_NETRESET ioctl")
l.Debug("failed to set UDP_NETRESET ioctl", "error", err)
}
err = u.rx.Open()
@@ -156,7 +156,7 @@ func (u *RIOConn) ListenOut(r EncReader) error {
// Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
lastRecvErr = time.Now()
u.l.WithError(err).Warn("unexpected udp socket receive error")
u.l.Warn("unexpected udp socket receive error", "error", err)
}
continue
}

View File

@@ -4,12 +4,13 @@
package udp
import (
"context"
"io"
"log/slog"
"net/netip"
"os"
"sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
)
@@ -46,10 +47,10 @@ type TesterConn struct {
done chan struct{}
closeOnce sync.Once
l *logrus.Logger
l *slog.Logger
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
return &TesterConn{
Addr: netip.AddrPortFrom(ip, uint16(port)),
RxPackets: make(chan *Packet, 10),
@@ -67,11 +68,12 @@ func (u *TesterConn) Send(packet *Packet) {
if err := h.Parse(packet.Data); err != nil {
panic(err)
}
if u.l.Level >= logrus.DebugLevel {
u.l.WithField("header", h).
WithField("udpAddr", packet.From).
WithField("dataLen", len(packet.Data)).
Debug("UDP receiving injected packet")
if u.l.Enabled(context.Background(), slog.LevelDebug) {
u.l.Debug("UDP receiving injected packet",
"header", h,
"udpAddr", packet.From,
"dataLen", len(packet.Data),
)
}
select {
case <-u.done:

View File

@@ -5,14 +5,13 @@ package udp
import (
"fmt"
"log/slog"
"net"
"net/netip"
"syscall"
"github.com/sirupsen/logrus"
)
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
if multi {
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
// The udp stack would need to be reworked to hide away the implementation differences between
@@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return rc, nil
}
l.WithError(err).Error("Falling back to standard udp sockets")
l.Error("Falling back to standard udp sockets", "error", err)
return NewGenericListener(l, ip, port, multi, batch)
}

View File

@@ -1,10 +1,10 @@
package util
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sort"

	"github.com/sirupsen/logrus"
)
type ContextualError struct {
@@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error {
}
// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) {
switch v := err.(type) {
case *ContextualError:
v.Log(l)
default:
l.WithError(err).Error(msg)
l.Error(msg, "error", err)
}
}
@@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error {
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
// Log emits ce as a single error-level log line with Fields and RealError
// promoted to top-level attributes, producing a flat shape callers can grep
// or parse without walking into a nested object.
func (ce *ContextualError) Log(l *slog.Logger) {
attrs := make([]slog.Attr, 0, len(ce.Fields)+1)
for k, v := range ce.Fields {
attrs = append(attrs, slog.Any(k, v))
}
if ce.RealError != nil {
attrs = append(attrs, slog.Any("error", ce.RealError))
}
// LogAttrs is intentional: attrs is built from a map[string]any so it has
// no pair-form equivalent.
//nolint:sloglint
l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...)
}

View File

@@ -1,95 +1,67 @@
package util
import (
"bytes"
"errors"
"fmt"
"testing"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
type m = map[string]any
// TestLogWriter is an io.Writer that records each Write call as its own
// string, letting tests assert on individual log lines.
type TestLogWriter struct {
	Logs []string
}

// NewTestLogWriter returns a TestLogWriter ready to capture writes.
func NewTestLogWriter() *TestLogWriter {
	return &TestLogWriter{Logs: []string{}}
}

// Write appends p as one captured entry and reports the full length consumed.
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
	entry := string(p)
	tl.Logs = append(tl.Logs, entry)
	return len(entry), nil
}

// Reset discards all captured entries while keeping the backing storage.
func (tl *TestLogWriter) Reset() {
	tl.Logs = tl.Logs[:0]
}
// TestContextualError_Log pins the exact slog text-handler output for every
// combination of context message, fields, and wrapped error.
func TestContextualError_Log(t *testing.T) {
	buf := &bytes.Buffer{}
	l := test.NewLoggerWithOutput(buf)

	// Full line: message, a field, and a wrapped error.
	buf.Reset()
	e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
	e.Log(l)
	assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())

	// Message and error, no fields.
	buf.Reset()
	e = NewContextualError("test message", nil, errors.New("error"))
	e.Log(l)
	assert.Equal(t, "level=ERROR msg=\"test message\" error=error\n", buf.String())

	// Message and fields, no error.
	buf.Reset()
	e = NewContextualError("test message", m{"field": "1"}, nil)
	e.Log(l)
	assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String())

	// Message only.
	buf.Reset()
	e = NewContextualError("test message", nil, nil)
	e.Log(l)
	assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String())

	// Error only; slog still renders an (empty) msg attribute.
	buf.Reset()
	e = NewContextualError("", nil, errors.New("error"))
	e.Log(l)
	assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String())
}
// TestLogWithContextIfNeeded verifies that a *ContextualError logs its own
// context (discarding the fallback message) while a plain error is logged
// with the fallback message and an "error" attribute.
func TestLogWithContextIfNeeded(t *testing.T) {
	buf := &bytes.Buffer{}
	l := test.NewLoggerWithOutput(buf)

	// A *ContextualError logs itself; the fallback context is thrown away.
	buf.Reset()
	e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
	LogWithContextIfNeeded("This should get thrown away", e, l)
	assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())

	// A plain error uses the fallback context as the message.
	buf.Reset()
	err := fmt.Errorf("this is a normal error")
	LogWithContextIfNeeded("Fallback context woo", err, l)
	assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String())
}
func TestContextualizeIfNeeded(t *testing.T) {