diff --git a/.golangci.yaml b/.golangci.yaml index bd82a952..be0513d4 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -2,7 +2,21 @@ version: "2" linters: default: none enable: + - sloglint - testifylint + settings: + sloglint: + # Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and + # the package-level slog equivalents. Use l.Log(ctx, level, ...) for + # custom levels instead of LogAttrs when you can. + # + # LogAttrs is also flagged by this rule because it takes ...slog.Attr; + # the few legitimate sites (where attrs is built up as a []slog.Attr) + # carry a //nolint:sloglint with rationale. + kv-only: true + # no-mixed-args is on by default: forbids mixing kv and attrs in one call. + # discard-handler is on by default (since Go 1.24): suggests + # slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil). exclusions: generated: lax presets: diff --git a/bits.go b/bits.go index af11cc48..5c8f902b 100644 --- a/bits.go +++ b/bits.go @@ -1,8 +1,10 @@ package nebula import ( + "context" + "log/slog" + "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" ) type Bits struct { @@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits { return b } -func (b *Bits) Check(l *logrus.Logger, i uint64) bool { +func (b *Bits) Check(l *slog.Logger, i uint64) bool { // If i is the next number, return true. if i > b.current { return true @@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool { } // Not within the window - if l.Level >= logrus.DebugLevel { - l.Debugf("rejected a packet (top) %d %d\n", b.current, i) + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("rejected a packet (top)", + "current", b.current, + "incoming", i, + ) } return false } -func (b *Bits) Update(l *logrus.Logger, i uint64) bool { +func (b *Bits) Update(l *slog.Logger, i uint64) bool { // If i is the next number, return true and update current. if i == b.current+1 { // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter @@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // Check to see if it's a duplicate if i > b.current-b.length || i < b.length && b.current < b.length { if b.current == i || b.bits[i%b.length] == true { - if l.Level >= logrus.DebugLevel { - l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). - Debug("Receive window") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "duplicate", + ) } b.dupeCounter.Inc(1) return false @@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // In all other cases, fail and don't change current. b.outOfWindowCounter.Inc(1) - if l.Level >= logrus.DebugLevel { - l.WithField("accepted", false). - WithField("currentCounter", b.current). - WithField("incomingCounter", i). - WithField("reason", "nonsense"). - Debug("Receive window") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Receive window", + "accepted", false, + "currentCounter", b.current, + "incomingCounter", i, + "reason", "nonsense", + ) } return false } diff --git a/cmd/nebula-service/logs_generic.go b/cmd/nebula-service/logs_generic.go index 3b7cdd1c..cc06b4c5 100644 --- a/cmd/nebula-service/logs_generic.go +++ b/cmd/nebula-service/logs_generic.go @@ -3,8 +3,15 @@ package main -import "github.com/sirupsen/logrus" +import ( + "log/slog" + "os" -func HookLogger(l *logrus.Logger) { - // Do nothing, let the logs flow to stdout/stderr + "github.com/slackhq/nebula/logging" +) + +// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows +// platforms have no special sink to integrate with. +func newPlatformLogger() *slog.Logger { + return logging.NewLogger(os.Stdout) } diff --git a/cmd/nebula-service/logs_windows.go b/cmd/nebula-service/logs_windows.go index af6480ef..ca0a55c5 100644 --- a/cmd/nebula-service/logs_windows.go +++ b/cmd/nebula-service/logs_windows.go @@ -1,54 +1,86 @@ package main import ( - "fmt" - "io/ioutil" - "os" + "context" + "log/slog" + "strings" + "sync" - "github.com/kardianos/service" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer -// logrus output will be discarded -func HookLogger(l *logrus.Logger) { - l.AddHook(newLogHook(logger)) - l.SetOutput(ioutil.Discard) +// newPlatformLogger returns a *slog.Logger that routes every log record +// through the Windows service logger so records end up in the Windows +// Event Log. All the heavy lifting (level management, format swap, +// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler; +// this file only contributes: +// +// - an io.Writer that forwards each formatted line to the service +// logger at the current record's Event Log severity, and +// - a thin severityTag that embeds *logging.Handler and overrides +// only Handle / WithAttrs / WithGroup, so Event Viewer's severity +// column and severity-based filters keep working the way they did +// before the slog migration. +// +// Format (text vs json) is carried by the embedded *logging.Handler, so +// logging.format: json in config still produces JSON lines in Event +// Viewer, same as the pre-slog logrus setup. +func newPlatformLogger() *slog.Logger { + w := &eventLogWriter{} + return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w}) } -type logHook struct { - sl service.Logger +// eventLogWriter forwards slog-formatted lines to the Windows service +// logger at the severity most recently stashed by severityTag.Handle. +// The mutex serializes the stash + inner.Handle + Write cycle per record +// across all concurrent goroutines; slog's builtin text/json handlers +// each hold their own mutex around Write, but that only protects the +// Write call itself, not our stash-then-handle sequence. +type eventLogWriter struct { + mu sync.Mutex + level slog.Level } -func newLogHook(sl service.Logger) *logHook { - return &logHook{sl: sl} -} - -func (h *logHook) Fire(entry *logrus.Entry) error { - line, err := entry.String() - if err != nil { - fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err) - return err - } - - switch entry.Level { - case logrus.PanicLevel: - return h.sl.Error(line) - case logrus.FatalLevel: - return h.sl.Error(line) - case logrus.ErrorLevel: - return h.sl.Error(line) - case logrus.WarnLevel: - return h.sl.Warning(line) - case logrus.InfoLevel: - return h.sl.Info(line) - case logrus.DebugLevel: - return h.sl.Info(line) +func (w *eventLogWriter) Write(p []byte) (int, error) { + line := strings.TrimRight(string(p), "\n") + switch { + case w.level >= slog.LevelError: + return len(p), logger.Error(line) + case w.level >= slog.LevelWarn: + return len(p), logger.Warning(line) default: - return nil + return len(p), logger.Info(line) } } -func (h *logHook) Levels() []logrus.Level { - return logrus.AllLevels +// severityTag embeds *logging.Handler to pick up everything it does for +// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat, +// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup +// so each record's slog.Level is stashed on the writer before formatting +// and so derived handlers stay wrapped as severityTag rather than +// downgrading to bare *logging.Handler. +type severityTag struct { + *logging.Handler + w *eventLogWriter +} + +func (s *severityTag) Handle(ctx context.Context, r slog.Record) error { + s.w.mu.Lock() + defer s.w.mu.Unlock() + s.w.level = r.Level + return s.Handler.Handle(ctx, r) +} + +func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return s + } + return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w} +} + +func (s *severityTag) WithGroup(name string) slog.Handler { + if name == "" { + return s + } + return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w} } diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 021e36fa..19fb3a9f 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -50,12 +50,11 @@ func main() { os.Exit(0) } - l := logrus.New() - l.Out = os.Stdout + l := logging.NewLogger(os.Stdout) if *serviceFlag != "" { if err := doService(configPath, configTest, Build, serviceFlag); err != nil { - l.WithError(err).Error("Service command failed") + l.Error("Service command failed", "error", err) os.Exit(1) } return @@ -74,6 +73,16 @@ func main() { os.Exit(1) } + if err := logging.ApplyConfig(l, c); err != nil { + fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) @@ -90,7 +99,7 @@ func main() { go ctrl.ShutdownBlock() if err := wait(); err != nil { - l.WithError(err).Error("Nebula stopped due to fatal error") + l.Error("Nebula stopped due to fatal error", "error", err) os.Exit(2) } diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index 1f45f95b..6551ceb4 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -7,9 +7,9 @@ import ( "path/filepath" "github.com/kardianos/service" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" ) var logger service.Logger @@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error { // Start should not block. logger.Info("Nebula service starting.") - l := logrus.New() - HookLogger(l) + l := newPlatformLogger() c := config.NewC(l) err := c.Load(*p.configPath) @@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error { return fmt.Errorf("failed to load config: %s", err) } + if err := logging.ApplyConfig(l, c); err != nil { + return fmt.Errorf("failed to apply logging config: %s", err) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + p.control, err = nebula.Main(c, *p.configTest, Build, l, nil) if err != nil { return err @@ -85,7 +93,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag * // Here are what the different loggers are doing: // - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr // - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log) - // - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use + // - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger s, err := service.New(prg, svcConfig) if err != nil { return err diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index f29f4537..d7f0de93 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -7,9 +7,9 @@ import ( "runtime/debug" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/util" ) @@ -55,8 +55,7 @@ func main() { os.Exit(1) } - l := logrus.New() - l.Out = os.Stdout + l := logging.NewLogger(os.Stdout) c := config.NewC(l) err := c.Load(*configPath) @@ -65,6 +64,16 @@ func main() { os.Exit(1) } + if err := logging.ApplyConfig(l, c); err != nil { + fmt.Printf("failed to apply logging config: %s", err) + os.Exit(1) + } + c.RegisterReloadCallback(func(c *config.C) { + if err := logging.ApplyConfig(l, c); err != nil { + l.Error("Failed to reconfigure logger on reload", "error", err) + } + }) + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) @@ -82,7 +91,7 @@ func main() { notifyReady(l) if err := wait(); err != nil { - l.WithError(err).Error("Nebula stopped due to fatal error") + l.Error("Nebula stopped due to fatal error", "error", err) os.Exit(2) } diff --git a/cmd/nebula/notify_linux.go b/cmd/nebula/notify_linux.go index 8c3dca55..965986a9 100644 --- a/cmd/nebula/notify_linux.go +++ b/cmd/nebula/notify_linux.go @@ -1,11 +1,10 @@ package main import ( + "log/slog" "net" "os" "time" - - "github.com/sirupsen/logrus" ) // SdNotifyReady tells systemd the service is ready and dependent services can now be started @@ -13,30 +12,30 @@ import ( // https://www.freedesktop.org/software/systemd/man/systemd.service.html const SdNotifyReady = "READY=1" -func notifyReady(l *logrus.Logger) { +func notifyReady(l *slog.Logger) { sockName := os.Getenv("NOTIFY_SOCKET") if sockName == "" { - l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") + l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal") return } conn, err := net.DialTimeout("unixgram", sockName, time.Second) if err != nil { - l.WithError(err).Error("failed to connect to systemd notification socket") + l.Error("failed to connect to systemd notification socket", "error", err) return } defer conn.Close() err = conn.SetWriteDeadline(time.Now().Add(time.Second)) if err != nil { - l.WithError(err).Error("failed to set the write deadline for the systemd notification socket") + l.Error("failed to set the write deadline for the systemd notification socket", "error", err) return } if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { - l.WithError(err).Error("failed to signal the systemd notification socket") + l.Error("failed to signal the systemd notification socket", "error", err) return } - l.Debugln("notified systemd the service is ready") + l.Debug("notified systemd the service is ready") } diff --git a/cmd/nebula/notify_notlinux.go b/cmd/nebula/notify_notlinux.go index e7758e09..48cfe949 100644 --- a/cmd/nebula/notify_notlinux.go +++ b/cmd/nebula/notify_notlinux.go @@ -3,8 +3,8 @@ package main -import "github.com/sirupsen/logrus" +import "log/slog" -func notifyReady(_ *logrus.Logger) { +func notifyReady(_ *slog.Logger) { // No init service to notify } diff --git a/config/config.go b/config/config.go index 0d1be128..5bf994a1 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "math" "os" "os/signal" @@ -16,7 +17,6 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "go.yaml.in/yaml/v3" ) @@ -26,11 +26,11 @@ type C struct { Settings map[string]any oldSettings map[string]any callbacks []func(*C) - l *logrus.Logger + l *slog.Logger reloadLock sync.Mutex } -func NewC(l *logrus.Logger) *C { +func NewC(l *slog.Logger) *C { return &C{ Settings: make(map[string]any), l: l, @@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool { newVals, err := yaml.Marshal(nv) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") + c.l.Error("Error while marshaling new config", + "config_path", k, + "error", err, + ) } oldVals, err := yaml.Marshal(ov) if err != nil { - c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") + c.l.Error("Error while marshaling old config", + "config_path", k, + "error", err, + ) } return string(newVals) != string(oldVals) @@ -154,7 +160,10 @@ func (c *C) ReloadConfig() { err := c.Load(c.path) if err != nil { - c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") + c.l.Error("Error occurred while reloading config", + "config_path", c.path, + "error", err, + ) return } diff --git a/connection_manager.go b/connection_manager.go index 4c2f26ef..e7fc04cd 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -5,13 +5,13 @@ import ( "context" "encoding/binary" "fmt" + "log/slog" "net/netip" "sync" "sync/atomic" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -47,10 +47,10 @@ type connectionManager struct { metricsTxPunchy metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { +func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ hostMap: hm, l: l, @@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.getInactivityTimeout() cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) if !initial { - cm.l.WithField("oldDuration", old). - WithField("newDuration", cm.getInactivityTimeout()). - Info("Inactivity timeout has changed") + cm.l.Info("Inactivity timeout has changed", + "oldDuration", old, + "newDuration", cm.getInactivityTimeout(), + ) } } @@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) { old := cm.dropInactive.Load() cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) if !initial { - cm.l.WithField("oldBool", old). - WithField("newBool", cm.dropInactive.Load()). - Info("Drop inactive setting has changed") + cm.l.Info("Drop inactive setting has changed", + "oldBool", old, + "newBool", cm.dropInactive.Load(), + ) } } } @@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo var err error index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { - cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") + cm.l.Error("failed to migrate relay to new hostinfo", "error", err) continue } switch r.Type { @@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo msg, err := req.Marshal() if err != nil { - cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") + cm.l.Error("failed to marshal Control message to migrate relay", "error", err) } else { cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) - cm.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromAddr, - "relayTo": req.RelayToAddr, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddrs": newhostinfo.vpnAddrs}). - Info("send CreateRelayRequest") + cm.l.Info("send CreateRelayRequest", + "relayFrom", req.RelayFromAddr, + "relayTo", req.RelayToAddr, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddrs", newhostinfo.vpnAddrs, + ) } } } @@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim hostinfo := cm.hostMap.Indexes[localIndex] if hostinfo == nil { - cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") + cm.l.Debug("Not found in hostmap", "localIndex", localIndex) return doNothing, nil, nil } @@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim // A hostinfo is determined alive if there is incoming traffic if inTraffic { decision := doNothing - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "alive", "method": "passive"}, + ) } hostinfo.pendingDeletion.Store(false) @@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if hostinfo.pendingDeletion.Load() { // We have already sent a test packet and nothing was returned, this hostinfo is dead - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "dead", "method": "active"}). - Info("Tunnel status") + hostinfo.logger(cm.l).Info("Tunnel status", + "tunnelCheck", m{"state": "dead", "method": "active"}, + ) return deleteTunnel, hostinfo, nil } @@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim inactiveFor, isInactive := cm.isInactive(hostinfo, now) if isInactive { // Tunnel is inactive, tear it down - hostinfo.logger(cm.l). - WithField("inactiveDuration", inactiveFor). - WithField("primary", mainHostInfo). - Info("Dropping tunnel due to inactivity") + hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity", + "inactiveDuration", inactiveFor, + "primary", mainHostInfo, + ) return closeTunnel, hostinfo, primary } @@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim cm.sendPunch(hostinfo) } - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l). - WithField("tunnelCheck", m{"state": "testing", "method": "active"}). - Debug("Tunnel status") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Tunnel status", + "tunnelCheck", m{"state": "testing", "method": "active"}, + ) } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues decision = sendTestPacket } else { - if cm.l.Level >= logrus.DebugLevel { - hostinfo.logger(cm.l).Debugf("Hostinfo sadness") + if cm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(cm.l).Debug("Hostinfo sadness") } } @@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI return false //cert is still valid! yay! } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed // Block listed certificates should always be disconnected - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). - Info("Remote certificate is blocked, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else if cm.intf.disconnectInvalid.Load() { - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). - Info("Remote certificate is no longer valid, tearing down the tunnel") + hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel", + "error", err, + "fingerprint", remoteCert.Fingerprint, + ) return true } else { //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open @@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { curCrtVersion := curCrt.Version() myCrt := cs.getCertificate(curCrtVersion) if myCrt == nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("reason", "local certificate removed"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "reason", "local certificate removed", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } @@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { // if our certificate version is less than theirs, and we have a matching version available, rehandshake? if cs.getCertificate(peerCrt.Certificate.Version()) != nil { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("version", curCrtVersion). - WithField("peerVersion", peerCrt.Certificate.Version()). - WithField("reason", "local certificate version lower than peer, attempting to correct"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "version", curCrtVersion, + "peerVersion", peerCrt.Certificate.Version(), + "reason", "local certificate version lower than peer, attempting to correct", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { hh.initiatingVersionOverride = peerCrt.Certificate.Version() }) @@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { } } if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "local certificate is not current"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "local certificate is not current", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } if curCrtVersion < cs.initiatingVersion { - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "current cert version < pki.initiatingVersion"). - Info("Re-handshaking with remote") + cm.l.Info("Re-handshaking with remote", + "vpnAddrs", hostinfo.vpnAddrs, + "reason", "current cert version < pki.initiatingVersion", + ) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return diff --git a/connection_manager_test.go b/connection_manager_test.go index 647dd72b..a015fba9 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/overlaytest" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) @@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce @@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, - inside: &test.NoopTun{}, + inside: &overlaytest.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, @@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.disconnectInvalid.Store(true) // Create manager - conf := config.NewC(l) - punchy := NewPunchyFromConfig(l, conf) - nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + conf := config.NewC(test.NewLogger()) + punchy := NewPunchyFromConfig(test.NewLogger(), conf) + nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc diff --git a/connection_state.go b/connection_state.go index db885d42..b85aebd4 100644 --- a/connection_state.go +++ b/connection_state.go @@ -8,7 +8,6 @@ import ( "sync/atomic" "github.com/flynn/noise" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/noiseutil" ) @@ -27,7 +26,7 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { +func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc switch crt.Curve() { case cert.Curve_CURVE25519: diff --git a/control.go b/control.go index 75eccef1..ef58988b 100644 --- a/control.go +++ b/control.go @@ -3,13 +3,13 @@ package nebula import ( "context" "errors" + "log/slog" "net/netip" "os" "os/signal" "sync" "syscall" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" @@ -46,7 +46,7 @@ type Control struct { state RunState f *Interface - l *logrus.Logger + l *slog.Logger ctx context.Context cancel context.CancelFunc sshStart func() @@ -151,7 +151,7 @@ func (c *Control) Stop() { c.CloseAllTunnels(false) if err := c.f.Close(); err != nil { - c.l.WithError(err).Error("Close interface failed") + c.l.Error("Close interface failed", "error", err) } c.stateLock.Lock() c.state = StateStopped @@ -166,7 +166,7 @@ func (c *Control) ShutdownBlock() { rawSig := <-sigChan sig := rawSig.String() - c.l.WithField("signal", sig).Info("Caught signal, shutting down") + c.l.Info("Caught signal, shutting down", "signal", sig) c.Stop() } @@ -303,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). - Debug("Sending close tunnel message") + c.l.Debug("Sending close tunnel message", + "vpnAddrs", h.vpnAddrs, + "udpAddr", h.remote, + ) closed++ } diff --git a/control_test.go b/control_test.go index 558d8669..5e381c46 100644 --- a/control_test.go +++ b/control_test.go @@ -6,7 +6,6 @@ import ( "reflect" "testing" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -83,7 +82,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { f: &Interface{ hostMap: hm, }, - l: logrus.New(), + l: test.NewLogger(), } thi := c.GetHostInfoByVpnAddr(vpnIp, false) diff --git a/dns_server.go b/dns_server.go index 5b12b922..ff1369ab 100644 --- a/dns_server.go +++ b/dns_server.go @@ -3,6 +3,7 @@ package nebula import ( "context" "fmt" + "log/slog" "net" "net/netip" "strconv" @@ -12,13 +13,12 @@ import ( "github.com/gaissmai/bart" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type dnsServer struct { sync.RWMutex - l *logrus.Logger + l *slog.Logger ctx context.Context dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr @@ -55,7 +55,7 @@ type dnsServer struct { // they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel // watcher that tears the listener down on nebula shutdown. The returned // pointer is always non-nil, even on error. -func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { +func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { ds := &dnsServer{ l: l, ctx: ctx, @@ -69,7 +69,7 @@ func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState c.RegisterReloadCallback(func(c *config.C) { if err := ds.reload(c, false); err != nil { - l.WithError(err).Error("Failed to reload DNS responder from config") + ds.l.Error("Failed to reload DNS responder from config", "error", err) } }) @@ -145,7 +145,7 @@ func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reaso <-started } if err := srv.Shutdown(); err != nil { - d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder") + d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err) } } @@ -188,7 +188,7 @@ func (d *dnsServer) Start() { } }() - d.l.WithField("dnsListener", addr).Info("Starting DNS responder") + d.l.Info("Starting DNS responder", "dnsListener", addr) err := server.ListenAndServe() close(done) @@ -201,7 +201,7 @@ func (d *dnsServer) Start() { } if err != nil { - d.l.WithError(err).Warn("Failed to run the DNS responder") + d.l.Warn("Failed to run the DNS responder", "error", err) } } @@ -314,6 +314,7 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { } func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { + debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug) // Per RFC 2308 §2.2, a name that exists but has no record of the requested // type must be answered with NOERROR and an empty answer section (NODATA), // not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not @@ -323,7 +324,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: qType := dns.TypeToString[q.Qtype] - d.l.Debugf("Query for %s %s", qType, q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", qType, "name", q.Name) + } ip, nameExists := d.Query(q.Qtype, q.Name) if nameExists { anyNameExists = true @@ -339,7 +342,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - d.l.Debugf("Query for TXT %s", q.Name) + if debugEnabled { + d.l.Debug("DNS query", "type", "TXT", "name", q.Name) + } ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) diff --git a/dns_server_test.go b/dns_server_test.go index e09d3fa9..dcea046c 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -2,7 +2,7 @@ package nebula import ( "context" - "io" + "log/slog" "net" "net/netip" "strconv" @@ -10,7 +10,6 @@ import ( "time" "github.com/miekg/dns" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -30,7 +29,7 @@ func (stubDNSWriter) TsigTimersOnly(bool) {} func (stubDNSWriter) Hijack() {} func TestParsequery(t *testing.T) { - l := logrus.New() + l := slog.New(slog.DiscardHandler) hostMap := &HostMap{} ds := &dnsServer{ l: l, @@ -137,10 +136,9 @@ func Test_getDnsServerAddr(t *testing.T) { func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { t.Helper() - l := logrus.New() - l.Out = io.Discard + sl := slog.New(slog.DiscardHandler) ds := &dnsServer{ - l: l, + l: sl, ctx: context.Background(), dnsMap4: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr), @@ -148,7 +146,7 @@ func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { } ds.mux = dns.NewServeMux() ds.mux.HandleFunc(".", ds.handleDnsRequest) - return ds, config.NewC(l) + return ds, config.NewC(nil) } func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 7729465b..93f200ac 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -11,7 +11,6 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" @@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) { myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) - l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - l.Info("Get a tunnel between me and relay") assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - l.Info("Get a tunnel between them and relay") assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - l.Info("Trigger a handshake from both them and me via relay to them and me") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) - r.Log("Wait for a packet from them to me") - l.Info("Wait for a packet from them to me; myControl") + r.Log("Wait for a packet from them to me; myControl") r.RouteForAllUntilTxTun(myControl) - l.Info("Wait for a packet from them to me; theirControl") + r.Log("Wait for a packet from them to me; theirControl") r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") - l.Info("Wait until we remove extra tunnels") - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) retries := 60 for hostInfos > 6 && retries > 0 { hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) - l.WithFields( - logrus.Fields{ - "myControl": len(myControl.GetHostmap().Indexes), - "theirControl": len(theirControl.GetHostmap().Indexes), - "relayControl": len(relayControl.GetHostmap().Indexes), - }).Info("Waiting for hostinfos to be removed...") + t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d", + len(myControl.GetHostmap().Indexes), + len(theirControl.GetHostmap().Indexes), + len(relayControl.GetHostmap().Indexes), + ) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) @@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) { } r.Log("Assert the tunnel works") - l.Info("Assert the tunnel works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 39843efe..381ae897 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,7 +4,6 @@ package e2e import ( - "fmt" "io" "net/netip" "os" @@ -12,15 +11,18 @@ import ( "testing" "time" + "log/slog" + "dario.cat/mergo" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" + "github.com/slackhq/nebula/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.yaml.in/yaml/v3" @@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), - "level": l.Level.String(), + "level": testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o "port": udpAddr.Port(), }, "logging": m{ - "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), - "level": l.Level.String(), + "level": testLogLevelName(), }, "timers": m{ "pending_deletion_interval": 2, @@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr { return a } -func NewTestLogger() *logrus.Logger { - l := logrus.New() - +func NewTestLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - l.SetLevel(logrus.PanicLevel) - return l + return slog.New(slog.NewTextHandler(io.Discard, nil)) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace } - - return l + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// testLogLevelName returns the level name string accepted by logging.ApplyConfig +// for the current TEST_LOGS setting. Kept in sync with NewTestLogger. +func testLogLevelName() string { + switch os.Getenv("TEST_LOGS") { + case "2": + return "debug" + case "3": + return "trace" + case "": + return "info" + } + return "info" } diff --git a/examples/config.yml b/examples/config.yml index 5bb87d8e..b02b3d58 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -292,23 +292,17 @@ tun: # Configure logging level logging: - # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. - #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some - # scenarios. Debug logging is also CPU intensive and will decrease performance overall. - # Only enable debug logging while actively investigating an issue. + # trace, debug, info, warn, or error. Default is info and is reloadable. + # fatal and panic are accepted for backwards compatibility and map to error. + #NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some + # scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall. + # Only enable debug or trace logging while actively investigating an issue. level: info - # json or text formats currently available. Default is text + # json or text formats currently available. Default is text. format: text - # Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false + # Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false. #disable_timestamp: true - # timestamp format is specified in Go time format, see: - # https://golang.org/pkg/time/#pkg-constants - # default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339) - # default when `format: text`: - # when TTY attached: seconds since beginning of execution - # otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339) - # As an example, to log as RFC3339 with millisecond precision, set to: - #timestamp_format: "2006-01-02T15:04:05.000Z07:00" + # Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable. #stats: #type: graphite diff --git a/examples/go_service/main.go b/examples/go_service/main.go index 2f8efbfb..3f98fe3d 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -7,9 +7,9 @@ import ( "net" "os" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/service" ) @@ -64,8 +64,7 @@ pki: return err } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/firewall.go b/firewall.go index 93b16891..adecbe81 100644 --- a/firewall.go +++ b/firewall.go @@ -1,11 +1,13 @@ package nebula import ( + "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "hash/fnv" + "log/slog" "net/netip" "reflect" "slices" @@ -16,7 +18,6 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -67,7 +68,7 @@ type Firewall struct { incomingMetrics firewallMetrics outgoingMetrics firewallMetrics - l *logrus.Logger + l *slog.Logger } type firewallMetrics struct { @@ -131,7 +132,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { +func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.InSendReject = false default: - l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`") + l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction) fw.InSendReject = false } @@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew case "drop": fw.OutSendReject = false default: - l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") + l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction) fw.OutSendReject = false } @@ -268,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort case firewall.ProtoICMP, firewall.ProtoICMPv6: //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided if startPort != firewall.PortAny { - f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule") + f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort) } startPort = firewall.PortAny endPort = firewall.PortAny @@ -290,8 +291,9 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). - Info("Firewall rule added") + f.l.Info("Firewall rule added", + "firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}, + ) return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } @@ -314,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } -func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { table = "firewall.inbound" @@ -372,7 +374,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw startPort = firewall.PortAny endPort = firewall.PortAny if sPort != "" { - l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule") + l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort) } default: return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) @@ -396,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw } if warning := r.sanity(); warning != nil { - l.Warnf("%s rule #%v; %s", table, i, warning) + l.Warn("firewall rule sanity check", + "table", table, + "rule", i, + "warning", warning, + ) } err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) @@ -528,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, // We now know which firewall table to check against if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). - WithField("oldRulesVersion", c.rulesVersion). - Debugln("dropping old conntrack entry, does not match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } delete(conntrack.Conns, fp) conntrack.Unlock() return false } - if f.l.Level >= logrus.DebugLevel { - h.logger(f.l). - WithField("fwPacket", fp). - WithField("incoming", c.incoming). - WithField("rulesVersion", f.rulesVersion). - WithField("oldRulesVersion", c.rulesVersion). - Debugln("keeping old conntrack entry, does match new ruleset") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset", + "fwPacket", fp, + "incoming", c.incoming, + "rulesVersion", f.rulesVersion, + "oldRulesVersion", c.rulesVersion, + ) } c.rulesVersion = f.rulesVersion @@ -935,7 +941,7 @@ type rule struct { CASha string } -func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { +func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) { r := rule{} m, ok := p.(map[string]any) @@ -966,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } - l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) + l.Warn("group was an array with a single value, converting to simple value", + "table", table, + "rule", i, + ) m["group"] = v[0] } diff --git a/firewall/cache.go b/firewall/cache.go index a4ffc100..ba4b9732 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -2,10 +2,9 @@ package firewall import ( "context" + "log/slog" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // ConntrackCache is used as a local routine cache to know if a given flow @@ -16,15 +15,17 @@ type ConntrackCacheTicker struct { cacheV uint64 cacheTick atomic.Uint64 + l *slog.Logger cache ConntrackCache } -func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker { +func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker { if d == 0 { return nil } c := &ConntrackCacheTicker{ + l: l, cache: ConntrackCache{}, } @@ -48,15 +49,15 @@ func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) { // Get checks if the cache ticker has moved to the next version before returning // the map. If it has moved, we reset the map. -func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { +func (c *ConntrackCacheTicker) Get() ConntrackCache { if c == nil { return nil } if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { - if l.Level == logrus.DebugLevel { - l.WithField("len", ll).Debug("resetting conntrack cache") + if c.l.Enabled(context.Background(), slog.LevelDebug) { + c.l.Debug("resetting conntrack cache", "len", ll) } c.cache = make(ConntrackCache, ll) } diff --git a/firewall/cache_test.go b/firewall/cache_test.go new file mode 100644 index 00000000..ab807984 --- /dev/null +++ b/firewall/cache_test.go @@ -0,0 +1,69 @@ +package firewall + +import ( + "bytes" + "log/slog" + "strings" + "testing" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +// The tests below pin the log format produced by ConntrackCacheTicker.Get +// so changes cannot silently break what operators are grepping for. The +// ticker's internal state (cache + cacheTick) is poked directly to avoid +// racing a goroutine-driven tick in tests. + +func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker { + t.Helper() + c := &ConntrackCacheTicker{ + l: l, + cache: make(ConntrackCache, cacheLen), + } + for i := 0; i < cacheLen; i++ { + c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{} + } + c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path + return c +} + +func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 3) + c.Get() + + assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String()) +} + +func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 2) + c.Get() + + assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String())) +} + +func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo) + + c := newFixedTicker(t, l, 5) + c.Get() + + assert.Empty(t, buf.String()) +} + +func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) { + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + + c := newFixedTicker(t, l, 0) + c.Get() + + assert.Empty(t, buf.String()) +} diff --git a/firewall_test.go b/firewall_test.go index a2133760..cbf090fd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -3,13 +3,13 @@ package nebula import ( "bytes" "errors" + "log/slog" "math" "net/netip" "testing" "time" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) { } func TestFirewall_AddRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) @@ -177,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) { } func TestFirewall_Drop(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ @@ -254,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) { } func TestFirewall_DropV6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -485,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) { } func TestFirewall_Drop2(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -544,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) { } func TestFirewall_Drop3(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -633,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) { } func TestFirewall_Drop3V6(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) @@ -671,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) { } func TestFirewall_DropConntrackReload(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -736,9 +729,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { } func TestFirewall_ICMPPortBehavior(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) @@ -880,9 +872,8 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { } func TestFirewall_DropIPSpoofing(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) @@ -1045,25 +1036,25 @@ func TestNewFirewallFromConfig(t *testing.T) { cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) require.NoError(t, err) - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") @@ -1073,25 +1064,25 @@ func TestNewFirewallFromConfig(t *testing.T) { require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") @@ -1100,35 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) { l := test.NewLogger() // Test adding tcp rule - conf := config.NewC(l) + conf := config.NewC(test.NewLogger()) mf := &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding udp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule no port - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -1136,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr cidr := netip.MustParsePrefix("10.0.0.0/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) // Test adding rule with local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) @@ -1151,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding rule with cidr ipv6 cidr6 := netip.MustParsePrefix("fd00::/8") - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) // Test adding rule with any cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) // Test adding rule with junk cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with local_cidr ipv6 - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) // Test adding rule with any local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) // Test adding rule with junk local_cidr - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with ca_sha - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) // Test single group - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test single groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test multiple AND groups - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) // Test Add error - conf = config.NewC(l) + conf = config.NewC(test.NewLogger()) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} @@ -1234,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { } func TestFirewall_convertRule(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) // Ensure group array of 1 is converted and a warning is printed c := map[string]any{ @@ -1244,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) { } r, err := convertRule(l, c, "test", 1) - assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value") + assert.Contains(t, ob.String(), "table=test") + assert.Contains(t, ob.String(), "rule=1") require.NoError(t, err) assert.Equal(t, []string{"group1"}, r.Groups) @@ -1270,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) { } func TestFirewall_convertRuleSanity(t *testing.T) { - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) noWarningPlease := []map[string]any{ {"group": "group1"}, @@ -1386,7 +1377,7 @@ type testsetup struct { fw *Firewall } -func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { +func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup { c := dummyCert{ name: "me", networks: myPrefixes, @@ -1397,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse return newSetupFromCert(t, l, c) } -func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { +func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup { myVpnNetworksTable := new(bart.Lite) for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) @@ -1414,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { t.Parallel() - l := test.NewLogger() ob := &bytes.Buffer{} - l.SetOutput(ob) + l := test.NewLoggerWithOutput(ob) myPrefix := netip.MustParsePrefix("1.1.1.1/8") // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out diff --git a/go.mod b/go.mod index 169cf1ca..0de2df7d 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.23.2 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 - github.com/sirupsen/logrus v1.9.4 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index d56177b7..aad164c7 100644 --- a/go.sum +++ b/go.sum @@ -133,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= -github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= diff --git a/handshake_ix.go b/handshake_ix.go index f081eb8c..a086960e 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,11 +2,12 @@ package nebula import ( "bytes" + "context" + "log/slog" "net/netip" "time" "github.com/flynn/noise" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" ) @@ -18,8 +19,11 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") + f.l.Error("Failed to generate index", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { crt := cs.getCertificate(v) if crt == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Unable to handshake with host because no certificate is available") + f.l.Error("Unable to handshake with host because no certificate is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } crtHs := cs.getHandshakeBytes(v) if crtHs == nil { - f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Unable to handshake with host because no certificate handshake bytes is available") + f.l.Error("Unable to handshake with host because no certificate handshake bytes is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } - ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", v). - Error("Failed to create connection state") + f.l.Error("Failed to create connection state", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", v, + ) return false } hh.hostinfo.ConnectionState = ci @@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("certVersion", v). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + f.l.Error("Failed to marshal handshake message", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "certVersion", v, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + f.l.Error("Failed to call noise.WriteMessage", + "error", err, + "vpnAddrs", hh.hostinfo.vpnAddrs, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + ) return false } @@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) cs := f.pki.getCertState() crt := cs.GetDefaultCertificate() if crt == nil { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). - WithField("certVersion", cs.initiatingVersion). - Error("Unable to handshake with host because no certificate is available") + f.l.Error("Unable to handshake with host because no certificate is available", + "from", via, + "handshake", m{"stage": 0, "style": "ix_psk0"}, + "certVersion", cs.initiatingVersion, + ) return } - ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to create connection state") + f.l.Error("Failed to create connection state", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to call noise.ReadMessage") + f.l.Error("Failed to call noise.ReadMessage", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed unmarshal handshake message") + f.l.Error("Failed unmarshal handshake message", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") + f.l.Info("Handshake did not contain a certificate", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) fp = "" } - e := f.l.WithError(err).WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("certVpnNetworks", rc.Networks()). - WithField("certFingerprint", fp) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) + attrs := []slog.Attr{ + slog.Any("error", err), + slog.Any("from", via), + slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}), + slog.Any("certVpnNetworks", rc.Networks()), + slog.String("certFingerprint", fp), + } + if f.l.Enabled(context.Background(), slog.LevelDebug) { + attrs = append(attrs, slog.Any("cert", rc)) } - e.Info("Invalid certificate from host") + // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that + // callers grow conditionally, which has no pair-form equivalent. + //nolint:sloglint + f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) return } if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") + f.l.Info("public key mismatch between certificate and handshake", + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "cert", remoteCert, + ) return } @@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) // We started off using the wrong certificate version, lets see if we can match the version that was sent to us myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) if myCertOtherVersion == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithError(err).WithFields(m{ - "from": via, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - "cert": remoteCert, - }).Debug("Might be unable to handshake with host due to missing certificate version") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Might be unable to handshake with host due to missing certificate version", + "error", err, + "from", via, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "cert", remoteCert, + ) } } else { // Record the certificate we are actually using @@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("No networks in certificate") + f.l.Info("No networks in certificate", + "error", err, + "from", via, + "cert", remoteCert, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(network.Addr()) { - f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } vpnAddrs[i] = network.Addr() @@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) if !via.IsRelayed { // We only want to apply the remote allow list for direct tunnels here if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - Debug("lighthouse.remote_allow_list denied incoming handshake") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", vpnAddrs, + "from", via, + ) + } return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") + f.l.Error("Failed to generate index", + "error", err, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) }, } - msgRxL := f.l.WithFields(m{ - "vpnAddrs": vpnAddrs, - "from": via, - "certName": certName, - "certVersion": certVersion, - "fingerprint": fingerprint, - "issuer": issuer, - "initiatorIndex": hs.Details.InitiatorIndex, - "responderIndex": hs.Details.ResponderIndex, - "remoteIndex": h.RemoteIndex, - "handshake": m{"stage": 1, "style": "ix_psk0"}, - }) + msgRxL := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) if anyVpnAddrsInCommon { msgRxL.Info("Handshake message received") @@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) hs.Details.ResponderIndex = myIndex hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) if hs.Details.Cert == nil { - msgRxL.WithField("myCertVersion", ci.myCert.Version()). - Error("Unable to handshake with host because no certificate handshake bytes is available") + msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available", + "myCertVersion", ci.myCert.Version(), + ) return } @@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + f.l.Error("Failed to marshal handshake message", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + f.l.Error("Failed to call noise.WriteMessage", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") + f.l.Error("Noise did not arrive at a key", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } @@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) if !via.IsRelayed { err := f.outside.WriteTo(msg, via.UdpAddr) if err != nil { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - WithError(err).Error("Failed to send handshake message") + f.l.Error("Failed to send handshake message", + "vpnAddrs", existing.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + "error", err, + ) } else { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", existing.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + ) } return } else { @@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) } hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). - Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", existing.vpnAddrs, + "relay", via.relayHI.vpnAddrs[0], + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cached", true, + ) return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("oldHandshakeTime", existing.lastHandshakeTime). - WithField("newHandshakeTime", hostinfo.lastHandshakeTime). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake too old") + f.l.Info("Handshake too old", + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "oldHandshakeTime", existing.lastHandshakeTime, + "newHandshakeTime", hostinfo.lastHandshakeTime, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). - Error("Failed to add HostInfo due to localIndex collision") + f.l.Error("Failed to add HostInfo due to localIndex collision", + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "localIndex", hostinfo.localIndexId, + "collision", existing.vpnAddrs, + ) return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to add HostInfo to HostMap") + f.l.Error("Failed to add HostInfo to HostMap", + "error", err, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) return } } @@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) if !via.IsRelayed { err = f.outside.WriteTo(msg, via.UdpAddr) - log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + log := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) if err != nil { - log.WithError(err).Error("Failed to send handshake") + log.Error("Failed to send handshake", "error", err) } else { log.Info("Handshake message sent") } @@ -448,14 +533,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) // it's correctly marked as working. via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake message sent") + f.l.Info("Handshake message sent", + "vpnAddrs", vpnAddrs, + "relay", via.relayHI.vpnAddrs[0], + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) } f.connectionManager.AddTrafficWatch(hostinfo) @@ -483,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe if !via.IsRelayed { // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + ) + } return false } } @@ -491,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Error("Failed to call noise.ReadMessage") + f.l.Error("Failed to call noise.ReadMessage", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "header", h, + ) // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Error("Noise did not arrive at a key") + f.l.Error("Noise did not arrive at a key", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // This should be impossible in IX but just in case, if we get here then there is no chance to recover // the handshake state machine. Tear it down @@ -512,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + f.l.Error("Failed unmarshal handshake message", + "error", err, + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true @@ -521,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake did not contain a certificate") + f.l.Info("Handshake did not contain a certificate", + "error", err, + "from", via, + "vpnAddrs", hostinfo.vpnAddrs, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) return true } @@ -535,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe fp = "" } - e := f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("certFingerprint", fp). - WithField("certVpnNetworks", rc.Networks()) - - if f.l.Level >= logrus.DebugLevel { - e = e.WithField("cert", rc) + attrs := []slog.Attr{ + slog.Any("error", err), + slog.Any("from", via), + slog.Any("vpnAddrs", hostinfo.vpnAddrs), + slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}), + slog.String("certFingerprint", fp), + slog.Any("certVpnNetworks", rc.Networks()), + } + if f.l.Enabled(context.Background(), slog.LevelDebug) { + attrs = append(attrs, slog.Any("cert", rc)) } - e.Info("Invalid certificate from host") + // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that + // callers grow conditionally, which has no pair-form equivalent. + //nolint:sloglint + f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) return true } if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.WithField("from", via). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") + f.l.Info("public key mismatch between certificate and handshake", + "from", via, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "cert", remoteCert, + ) return true } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("from", via). - WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("cert", remoteCert). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("No networks in certificate") + f.l.Info("No networks in certificate", + "error", err, + "from", via, + "vpnAddrs", hostinfo.vpnAddrs, + "cert", remoteCert, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) return true } @@ -601,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe // Ensure the right host responded if !correctHostResponded { - f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Incorrect host responded to handshake") + f.l.Info("Incorrect host responded to handshake", + "intendedVpnAddrs", hostinfo.vpnAddrs, + "haveVpnNetworks", vpnNetworks, + "from", via, + "certName", certName, + "certVersion", certVersion, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + ) // Release our old handshake from pending, it should not continue f.handshakeManager.DeleteHostInfo(hostinfo) @@ -618,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes.BlockRemote(via) - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). - WithField("vpnNetworks", vpnNetworks). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). - Info("Blocked addresses for handshakes") + f.l.Info("Blocked addresses for handshakes", + "blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(), + "vpnNetworks", vpnNetworks, + "remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()), + ) // Swap the packet store to benefit the original intended recipient newHH.packetStore = hh.packetStore @@ -639,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("durationNs", duration). - WithField("sentCachedPackets", len(hh.packetStore)) + msgRxL := f.l.With( + "vpnAddrs", vpnAddrs, + "from", via, + "certName", certName, + "certVersion", certVersion, + "fingerprint", fingerprint, + "issuer", issuer, + "initiatorIndex", hs.Details.InitiatorIndex, + "responderIndex", hs.Details.ResponderIndex, + "remoteIndex", h.RemoteIndex, + "handshake", m{"stage": 2, "style": "ix_psk0"}, + "durationNs", duration, + "sentCachedPackets", len(hh.packetStore), + ) if anyVpnAddrsInCommon { msgRxL.Info("Handshake message received") } else { @@ -663,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Sending stored packets", + "count", len(hh.packetStore), + ) } if len(hh.packetStore) > 0 { diff --git a/handshake_manager.go b/handshake_manager.go index 25a59b6d..8040ec2e 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,13 +6,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "log/slog" "net/netip" "slices" "sync" "time" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -59,7 +59,7 @@ type HandshakeManager struct { metricInitiated metrics.Counter metricTimedOut metrics.Counter f *Interface - l *logrus.Logger + l *slog.Logger // can be used to trigger outbound handshake for the given vpnIp trigger chan netip.Addr @@ -78,32 +78,32 @@ type HandshakeHostInfo struct { hostinfo *HostInfo } -func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { +func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { if len(hh.packetStore) < 100 { tempPacket := make([]byte, len(packet)) copy(tempPacket, packet) hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", true). - Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", true, + ) } } else { m.dropped.Inc(1) - if l.Level >= logrus.DebugLevel { - hh.hostinfo.logger(l). - WithField("length", len(hh.packetStore)). - WithField("stored", false). - Debugf("Packet store") + if l.Enabled(context.Background(), slog.LevelDebug) { + hh.hostinfo.logger(l).Debug("Packet store", + "length", len(hh.packetStore), + "stored", false, + ) } } } -func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, @@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head // First remote allow list check before we know the vpnIp if !via.IsRelayed { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { - hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") + hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via) return } } @@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). - WithField("initiatorIndex", hh.hostinfo.localIndexId). - WithField("remoteIndex", hh.hostinfo.remoteIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). - Info("Handshake timed out") + hh.hostinfo.logger(hm.l).Info("Handshake timed out", + "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), + "initiatorIndex", hh.hostinfo.localIndexId, + "remoteIndex", hh.hostinfo.remoteIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "durationNs", time.Since(hh.startTime).Nanoseconds(), + ) hm.metricTimedOut.Inc(1) hm.DeleteHostInfo(hostinfo) return @@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { - hostinfo.logger(hm.l).WithField("udpAddr", addr). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake message") + hostinfo.logger(hm.l).Error("Failed to send handshake message", + "udpAddr", addr, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + "error", err, + ) } else { sentTo = append(sentTo, addr) @@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // so only log when the list of remotes has changed if remotesHaveChanged { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message sent") - } else if hm.l.Level >= logrus.DebugLevel { - hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Debug("Handshake message sent") + hostinfo.logger(hm.l).Info("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) + } else if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(hm.l).Debug("Handshake message sent", + "udpAddrs", sentTo, + "initiatorIndex", hostinfo.localIndexId, + "handshake", m{"stage": 1, "style": "ix_psk0"}, + ) } if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") + hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay through the host I'm trying to connect to @@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String()) hm.f.Handshake(relay) continue } @@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") + hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) } m := NebulaControl{ @@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered msg, err := m.Marshal() if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") + hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": idx, - "relay": relay}). - Info("send CreateRelayRequest") + hm.l.Info("send CreateRelayRequest", + "relayFrom", hm.f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", idx, + "relay", relay, + ) } } continue @@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered switch existingRelay.State { case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String()) hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) fallthrough case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String()) // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, @@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } msg, err := m.Marshal() if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") + hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": relay}). - Info("send CreateRelayRequest") + hm.l.Info("send CreateRelayRequest", + "relayFrom", hm.f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", existingRelay.LocalIndex, + "relay", relay, + ) } case PeerRequested: // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. fallthrough default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relay). - Errorf("Relay unexpected state") + hostinfo.logger(hm.l).Error("Relay unexpected state", + "vpnIp", vpnIp, + "state", existingRelay.State, + "relay", relay, + ) } } @@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). - Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) @@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). - Info("New host shadows existing host remoteIndex") + hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", + "remoteIndex", hostinfo.remoteIndexId, + "collision", existingRemoteIndex.vpnAddrs, + ) } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. @@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { hm.indexes = map[uint32]*HandshakeHostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). - Debug("Pending hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Pending hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } } @@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() { // Utility functions below -func generateIndex(l *logrus.Logger) (uint32, error) { +func generateIndex(l *slog.Logger) (uint32, error) { b := make([]byte, 4) // Let zero mean we don't know the ID, so don't generate zero @@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) { for index == 0 { _, err := rand.Read(b) if err != nil { - l.Errorln(err) + l.Error("Failed to generate index", "error", err) return 0, err } index = binary.BigEndian.Uint32(b) } - if l.Level >= logrus.DebugLevel { - l.WithField("index", index). - Debug("Generated index") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("Generated index", "index", index) } return index, nil } diff --git a/hostmap.go b/hostmap.go index 25181d83..08acd1be 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,9 +1,11 @@ package nebula import ( + "context" "encoding/json" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -13,10 +15,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" ) const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address @@ -60,7 +62,7 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - l *logrus.Logger + l *slog.Logger } // For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay @@ -313,7 +315,7 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { +func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap { hm := newHostMap(l) hm.reload(c, true) @@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { hm.reload(c, false) }) - l.WithField("preferredRanges", hm.GetPreferredRanges()). - Info("Main HostMap created") + l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges()) return hm } -func newHostMap(l *logrus.Logger) *HostMap { +func newHostMap(l *slog.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, @@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { - hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") + hm.l.Warn("Failed to parse preferred ranges, ignoring", + "error", err, + "range", rawPreferredRanges, + ) continue } @@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) { oldRanges := hm.preferredRanges.Swap(&preferredRanges) if !initial { - hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") + hm.l.Info("preferred_ranges changed", + "oldPreferredRanges", *oldRanges, + "newPreferredRanges", preferredRanges, + ) } } } @@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad hm.Indexes = map[uint32]*HostInfo{} } - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). - Debug("Hostmap hostInfo deleted") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap hostInfo deleted", + "hostMap", m{"mapTotalSize": len(hm.Hosts), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}, + ) } if isLastHostinfo { @@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). - Debug("Hostmap vpnIp added") + if hm.l.Enabled(context.Background(), slog.LevelDebug) { + hm.l.Debug("Hostmap vpnIp added", + "hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}, + ) } } @@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica } } -func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { +// logger returns a derived slog.Logger with per-hostinfo fields pre-bound. +func (i *HostInfo) logger(l *slog.Logger) *slog.Logger { if i == nil { - return logrus.NewEntry(l) + return l } - li := l.WithField("vpnAddrs", i.vpnAddrs). - WithField("localIndex", i.localIndexId). - WithField("remoteIndex", i.remoteIndexId) + li := l.With( + "vpnAddrs", i.vpnAddrs, + "localIndex", i.localIndexId, + "remoteIndex", i.remoteIndexId, + ) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Certificate.Name()) + li = li.With("certName", peerCert.Certificate.Name()) } } @@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { +func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) - if l.Level >= logrus.TraceLevel { - l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName", + "interfaceName", i.Name, + "allow", allow, + ) } if !allow { @@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { } if !addr.IsValid() { - if l.Level >= logrus.DebugLevel { - l.WithField("localAddr", rawAddr).Debug("addr was invalid") + if l.Enabled(context.Background(), slog.LevelDebug) { + l.Debug("addr was invalid", "localAddr", rawAddr) } continue } @@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { isAllowed := allowList.Allow(addr) - if l.Level >= logrus.TraceLevel { - l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") + if l.Enabled(context.Background(), logging.LevelTrace) { + l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow", + "localAddr", addr, + "allowed", isAllowed, + ) } if !isAllowed { continue diff --git a/hostmap_test.go b/hostmap_test.go index e34a4ad0..2bd7bd43 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { func TestHostMap_reload(t *testing.T) { l := test.NewLogger() - c := config.NewC(l) + c := config.NewC(test.NewLogger()) hm := NewHostMapFromConfig(l, c) diff --git a/inside.go b/inside.go index 0d53f952..68cb38ec 100644 --- a/inside.go +++ b/inside.go @@ -1,9 +1,10 @@ package nebula import ( + "context" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" @@ -14,8 +15,11 @@ import ( func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while validating outbound packet", + "packet", packet, + "error", err, + ) } return } @@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) if err != nil { - f.l.WithError(err).Error("Failed to forward to tun") + f.l.Error("Failed to forward to tun", "error", err) } } // Otherwise, drop. On linux, we should never see these packets - Linux @@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet if hostinfo == nil { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", fwPacket.RemoteAddr). - WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", + "vpnAddr", fwPacket.RemoteAddr, + "fwPacket", fwPacket, + ) } return } @@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } else { f.rejectInside(packet, out, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l). - WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping outbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping outbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } } } @@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { _, err := f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } } @@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * } if len(out) > iputil.MaxRejectPacketSize { - if f.l.GetLevel() >= logrus.InfoLevel { - f.l. - WithField("packet", packet). - WithField("outPacket", out). - Info("rejectOutside: packet too big, not sending") + if f.l.Enabled(context.Background(), slog.LevelInfo) { + f.l.Info("rejectOutside: packet too big, not sending", + "packet", packet, + "outPacket", out, + ) } return } @@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac // This would also need to interact with unsafe_route updates through reloading the config or // use of the use_system_route_table option - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("destination", destinationAddr). - WithField("originalGateway", gatewayAddr). - Debugln("Calculated gateway for ECMP not available, attempting other gateways") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways", + "destination", destinationAddr, + "originalGateway", gatewayAddr, + ) } for i := range gateways { @@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { - f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) + f.l.Warn("error while parsing outgoing packet for firewall check", "error", err) return } // check if packet is in outbound fw rules dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("fwPacket", fp). - WithField("reason", dropReason). - Debugln("dropping cached packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping cached packet", + "fwPacket", fp, + "reason", dropReason, + ) } return } @@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message }) if hostInfo == nil { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddr", vpnAddr). - Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes", + "vpnAddr", vpnAddr, + ) } return } @@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo, if noiseutil.EncryptLockNeeded { via.ConnectionState.writeLock.Unlock() } - via.logger(f.l). - WithField("outCap", cap(out)). - WithField("payloadLen", len(ad)). - WithField("headerLen", len(out)). - WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()). - Error("SendVia out buffer not large enough for relay") + via.logger(f.l).Error("SendVia out buffer not large enough for relay", + "outCap", cap(out), + "payloadLen", len(ad), + "headerLen", len(out), + "cipherOverhead", via.ConnectionState.eKey.Overhead(), + ) return } @@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo, via.ConnectionState.writeLock.Unlock() } if err != nil { - via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") + via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) return } err = f.writers[0].WriteTo(out, via.remote) if err != nil { - via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") + via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) } f.connectionManager.RelayUsed(relay.LocalIndex) } @@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Lighthouse update triggered for punch due to rebind counter", + "vpnAddrs", hostinfo.vpnAddrs, + ) } } @@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType ci.writeLock.Unlock() } if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).WithField("counter", c). - WithField("attemptedCounter", c). - Error("Failed to encrypt outgoing packet") + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", err, + "udpAddr", remote, + "counter", c, + "attemptedCounter", c, + ) return } if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { - hostinfo.logger(f.l).WithError(err). - WithField("udpAddr", remote).Error("Failed to write outgoing packet") + hostinfo.logger(f.l).Error("Failed to write outgoing packet", + "error", err, + "udpAddr", remote, + ) } } else { // Try to send via a relay @@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) - hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", + "relay", relayIP, + "error", err, + ) continue } f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) diff --git a/interface.go b/interface.go index 6d040884..5fedcdd3 100644 --- a/interface.go +++ b/interface.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "sync" "sync/atomic" @@ -12,7 +13,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -46,7 +47,7 @@ type InterfaceConfig struct { reQueryWait time.Duration ConntrackCacheTimeout time.Duration - l *logrus.Logger + l *slog.Logger } type Interface struct { @@ -100,7 +101,7 @@ type Interface struct { messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics - l *logrus.Logger + l *slog.Logger } type EncWriter interface { @@ -223,13 +224,16 @@ func (f *Interface) activate() error { addr, err := f.outside.LocalAddr() if err != nil { - f.l.WithError(err).Error("Failed to get udp listen address") + f.l.Error("Failed to get udp listen address", "error", err) } - f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). - WithField("build", f.version).WithField("udpAddr", addr). - WithField("boringcrypto", boringEnabled()). - Info("Nebula interface is active") + f.l.Info("Nebula interface is active", + "interface", f.inside.Name(), + "networks", f.myVpnNetworks, + "build", f.version, + "udpAddr", addr, + "boringcrypto", boringEnabled(), + ) if f.routines > 1 { if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() { @@ -305,7 +309,7 @@ func (f *Interface) listenOut(i int) { li = f.outside } - ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) + ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() plaintext := make([]byte, udp.MTU) h := &header.H{} @@ -313,15 +317,15 @@ func (f *Interface) listenOut(i int) { nb := make([]byte, 12, 12) err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) }) if err != nil && !f.closed.Load() { - f.l.WithError(err).Error("Error while reading inbound packet, closing") + f.l.Error("Error while reading inbound packet, closing", "error", err) f.onFatal(err) } - f.l.Debugf("underlay reader %v is done", i) + f.l.Debug("underlay reader is done", "reader", i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -330,22 +334,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) + conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) for { n, err := reader.Read(packet) if err != nil { if !f.closed.Load() { - f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") + f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) f.onFatal(err) } break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) } - f.l.Debugf("overlay reader %v is done", i) + f.l.Debug("overlay reader is done", "reader", i) } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -365,7 +369,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { if initial || c.HasChanged("pki.disconnect_invalid") { f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) if !initial { - f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) + f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load()) } } } @@ -379,7 +383,7 @@ func (f *Interface) reloadFirewall(c *config.C) { fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { - f.l.WithError(err).Error("Error while creating firewall during reload") + f.l.Error("Error while creating firewall during reload", "error", err) return } @@ -392,10 +396,11 @@ func (f *Interface) reloadFirewall(c *config.C) { // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { - f.l.WithField("firewallHashes", fw.GetRuleHashes()). - WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). - Warn("firewall rulesVersion has overflowed, resetting conntrack") + f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } else { fw.Conntrack = conntrack } @@ -403,10 +408,11 @@ func (f *Interface) reloadFirewall(c *config.C) { f.firewall = fw oldFw.Destroy() - f.l.WithField("firewallHashes", fw.GetRuleHashes()). - WithField("oldFirewallHashes", oldFw.GetRuleHashes()). - WithField("rulesVersion", fw.rulesVersion). - Info("New firewall has been installed") + f.l.Info("New firewall has been installed", + "firewallHashes", fw.GetRuleHashes(), + "oldFirewallHashes", oldFw.GetRuleHashes(), + "rulesVersion", fw.rulesVersion, + ) } func (f *Interface) reloadSendRecvError(c *config.C) { @@ -428,8 +434,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) { } } - f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()). - Info("Loaded send_recv_error config") + f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String()) } } @@ -452,8 +457,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) { } } - f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()). - Info("Loaded accept_recv_error config") + f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String()) } } @@ -527,7 +531,7 @@ func (f *Interface) Close() error { for i, u := range f.writers { err := u.Close() if err != nil { - f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket") + f.l.Error("Error while closing udp socket", "error", err, "writer", i) errs = append(errs, err) } } diff --git a/lighthouse.go b/lighthouse.go index 50140e9e..6034e68c 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -15,10 +16,10 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -76,12 +77,12 @@ type LightHouse struct { metrics *MessageMetrics metricHolepunchTx metrics.Counter - l *logrus.Logger + l *slog.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -133,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, case *util.ContextualError: v.Log(l) case error: - l.WithError(err).Error("failed to reload lighthouse") + l.Error("failed to reload lighthouse", "error", err) } }) @@ -205,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used addr := addrs[0].Unmap() if lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithField("addr", rawAddr).WithField("entry", i+1). - Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") + lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range", + "addr", rawAddr, + "entry", i+1, + ) continue } @@ -224,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10))) if !initial { - lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load()) + lh.l.Info("lighthouse.interval changed", + "interval", lh.interval.Load(), + ) if lh.updateCancel != nil { // May not always have a running routine @@ -336,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { for _, v := range c.GetStringSlice("relay.relays", nil) { configRIP, err := netip.ParseAddr(v) if err != nil { - lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + lh.l.Warn("Parse relay from config failed", + "relay", v, + "error", err, + ) } else { - lh.l.WithField("relay", v).Info("Read relay from config") + lh.l.Info("Read relay from config", "relay", v) relaysForMe = append(relaysForMe, configRIP) } } @@ -363,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) { } if !lh.myVpnNetworksTable.Contains(addr) { - lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). - Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not") + lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not", + "vpnAddr", addr, + "networks", lh.myVpnNetworks, + ) } out[i] = addr } @@ -435,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc } if !lh.myVpnNetworksTable.Contains(vpnAddr) { - lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). - Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work") + lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work", + "vpnAddr", vpnAddr, + "networks", lh.myVpnNetworks, + "entry", i+1, + ) } vals, ok := v.([]any) @@ -537,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { lh.Lock() rm, ok := lh.addrMap[allVpnAddrs[0]] if ok { + debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug) for _, addr := range allVpnAddrs { srm := lh.addrMap[addr] if srm == rm { delete(lh.addrMap, addr) - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", addr) + if debugEnabled { + lh.l.Debug("deleting from lighthouse", "vpnAddr", addr) } } } @@ -659,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow). - Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddrs", vpnAddrs, + "udpAddr", to, + "allow", allow, + ) } if !allow { return false @@ -678,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { udpAddr := protoV4AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). - Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -698,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { udpAddr := protoV6AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). - Trace("remoteAllowList.Allow") + if lh.l.Enabled(context.Background(), logging.LevelTrace) { + lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow", + "vpnAddr", vpnAddr, + "udpAddr", udpAddr, + "allow", allow, + ) } if !allow { @@ -775,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { if v == cert.Version1 { if !addr.Is4() { - lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). - Error("Can't query lighthouse for v6 address using a v1 protocol") + lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol", + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } @@ -787,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v1Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). - Error("Failed to marshal lighthouse v1 query payload") + lh.l.Error("Failed to marshal lighthouse v1 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -804,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { v2Query, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("queryVpnAddr", addr). - WithField("lighthouseAddr", lhVpnAddr). - Error("Failed to marshal lighthouse v2 query payload") + lh.l.Error("Failed to marshal lighthouse v2 query payload", + "error", err, + "queryVpnAddr", addr, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -815,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { queried++ } else { - lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + lh.l.Debug("unsupported protocol version", + "op", "query", + "queryVpnAddr", addr, + "version", v, + ) continue } } @@ -907,8 +940,9 @@ func (lh *LightHouse) SendUpdate() { if v == cert.Version1 { if v1Update == nil { if !lh.myVpnNetworks[0].Addr().Is4() { - lh.l.WithField("lighthouseAddr", lhVpnAddr). - Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address", + "lighthouseAddr", lhVpnAddr, + ) continue } var relays []uint32 @@ -932,8 +966,10 @@ func (lh *LightHouse) SendUpdate() { v1Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). - Error("Error while marshaling for lighthouse v1 update") + lh.l.Error("Error while marshaling for lighthouse v1 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -959,8 +995,10 @@ func (lh *LightHouse) SendUpdate() { v2Update, err = msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). - Error("Error while marshaling for lighthouse v2 update") + lh.l.Error("Error while marshaling for lighthouse v2 update", + "error", err, + "lighthouseAddr", lhVpnAddr, + ) continue } } @@ -969,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() { updated++ } else { - lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + lh.l.Debug("unsupported protocol version", + "op", "update", + "version", v, + ) continue } } @@ -983,7 +1024,7 @@ type LightHouseHandler struct { out []byte pb []byte meta *NebulaMeta - l *logrus.Logger + l *slog.Logger } func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { @@ -1032,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). - Error("Failed to unmarshal lighthouse packet") + lhh.l.Error("Failed to unmarshal lighthouse packet", + "error", err, + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). - Error("Invalid lighthouse update") + lhh.l.Error("Invalid lighthouse update", + "vpnAddrs", fromVpnAddrs, + "udpAddr", rAddr, + ) return } @@ -1067,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I don't answer queries, but received from: ", addr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I don't answer queries, but received one", "from", addr) } return } queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). - Debugln("Dropping malformed HostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Dropping malformed HostQuery", + "from", fromVpnAddrs, + "details", n.Details, + ) } return } if useVersion == cert.Version1 && queryVpnAddr.Is6() { // this case really shouldn't be possible to represent, but reject it anyway. - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). - Debugln("invalid vpn addr for v1 handleHostQuery") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("invalid vpn addr for v1 handleHostQuery", + "vpnAddrs", fromVpnAddrs, + "queryVpnAddr", queryVpnAddr, + ) } return } @@ -1110,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.Error("Failed to marshal lighthouse host query reply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1138,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd if ok { whereToPunch = newDest } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unable to punch to host, no addresses in common", + "to", crt.Networks(), + ) } } } @@ -1165,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") + lhh.l.Error("Failed to marshal lighthouse host was queried for", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1207,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) } } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("version", v).Debug("unsupported protocol version") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("unsupported protocol version", + "op", "coalesceAnswers", + "version", v, + ) } } } @@ -1221,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Error("dropping malformed HostQueryReply", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) } return } @@ -1247,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs) } return } @@ -1271,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp //Simple check that the host sent this not someone else, if detailsVpnAddr is filled if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Host sent invalid update", + "vpnAddrs", fromVpnAddrs, + "answer", detailsVpnAddr, + ) } return } @@ -1294,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp switch useVersion { case cert.Version1: if !fromVpnAddrs[0].Is4() { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message", + "vpnAddrs", fromVpnAddrs, + ) return } vpnAddrB := fromVpnAddrs[0].As4() @@ -1302,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp case cert.Version2: // do nothing, we want to send a blank message default: - lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + lhh.l.Error("invalid protocol version", "useVersion", useVersion) return } ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") + lhh.l.Error("Failed to marshal lighthouse host update ack", + "error", err, + "vpnAddrs", fromVpnAddrs, + ) return } @@ -1325,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() if err != nil { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification") + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("dropping invalid HostPunchNotification", + "details", n.Details, + "error", err, + ) } return } @@ -1343,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn lhh.lh.punchConn.WriteTo(empty, vpnPeer) }() - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Punching", + "vpnPeer", vpnPeer, + "logVpnAddr", logVpnAddr, + ) } } @@ -1369,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn if lhh.lh.punchy.GetRespond() { go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr) + if lhh.l.Enabled(context.Background(), slog.LevelDebug) { + lhh.l.Debug("Sending a nebula test packet", + "vpnAddr", detailsVpnAddr, + ) } //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine diff --git a/logger.go b/logger.go deleted file mode 100644 index aaf6f29c..00000000 --- a/logger.go +++ /dev/null @@ -1,45 +0,0 @@ -package nebula - -import ( - "fmt" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" -) - -func configLogger(l *logrus.Logger, c *config.C) error { - // set up our logging level - logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) - if err != nil { - return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) - } - l.SetLevel(logLevel) - - disableTimestamp := c.GetBool("logging.disable_timestamp", false) - timestampFormat := c.GetString("logging.timestamp_format", "") - fullTimestamp := (timestampFormat != "") - if timestampFormat == "" { - timestampFormat = time.RFC3339 - } - - logFormat := strings.ToLower(c.GetString("logging.format", "text")) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{ - TimestampFormat: timestampFormat, - FullTimestamp: fullTimestamp, - DisableTimestamp: disableTimestamp, - } - case "json": - l.Formatter = &logrus.JSONFormatter{ - TimestampFormat: timestampFormat, - DisableTimestamp: disableTimestamp, - } - default: - return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) - } - - return nil -} diff --git a/logging/logger.go b/logging/logger.go new file mode 100644 index 00000000..bbc10bb3 --- /dev/null +++ b/logging/logger.go @@ -0,0 +1,233 @@ +// Package logging wires the nebula runtime-reconfigurable slog handler used +// by nebula.Main and the nebula CLI binaries. Callers build a logger with +// NewLogger, then call ApplyConfig at startup and from a config reload +// callback to push logging.level, logging.format, and +// logging.disable_timestamp changes onto the logger without rebuilding it. +package logging + +import ( + "context" + "fmt" + "io" + "log/slog" + "strings" + "sync/atomic" + "time" +) + +// Config is the subset of *config.C that ApplyConfig reads. Declaring it +// here keeps the logging package from depending on config directly, which +// would cycle through the shared test helpers (test.NewLogger imports +// logging, and config's tests import test). *config.C satisfies this +// interface structurally with no adapter. +type Config interface { + GetString(key, def string) string + GetBool(key string, def bool) bool +} + +// LevelTrace is a custom slog level below Debug, used when logging.level is +// "trace". slog has no builtin trace level; the value is one step below +// slog.LevelDebug in slog's 4-point spacing. +const LevelTrace = slog.Level(-8) + +// NewLogger returns a *slog.Logger whose level, format, and timestamp +// emission can be reconfigured at runtime via ApplyConfig and the SSH debug +// commands. The default configuration is info-level text output so log +// calls made before ApplyConfig runs still produce output. Timestamps +// follow slog's default RFC3339Nano format; set logging.disable_timestamp +// in config to suppress them. +// +// ApplyConfig and the SSH commands discover the reconfig surface via +// structural type-assertion on l.Handler(), so replacement implementations +// (tests, platform-specific sinks) need only implement the subset of +// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)} +// they care about. Callers that pass a plain *slog.Logger without these +// methods get a silent no-op; reconfiguration is always opt-in. +func NewLogger(w io.Writer) *slog.Logger { + return slog.New(NewHandler(w)) +} + +// NewHandler builds the *Handler that NewLogger wraps. Exported for +// platform-specific sinks (notably cmd/nebula-service/logs_windows.go) +// that want to wrap the handler with extra behavior, such as tagging each +// record with its Event Log severity, while still benefiting from all the +// level / format / timestamp / WithAttrs machinery implemented here. +func NewHandler(w io.Writer) *Handler { + root := &handlerRoot{} + root.level.Set(slog.LevelInfo) + opts := &slog.HandlerOptions{Level: &root.level} + return &Handler{ + root: root, + text: slog.NewTextHandler(w, opts), + json: slog.NewJSONHandler(w, opts), + } +} + +// handlerRoot carries the reconfiguration state shared by every logger +// derived from a NewHandler call. All fields are consulted on the log +// path and updated lock-free. +type handlerRoot struct { + level slog.LevelVar + disableTimestamp atomic.Bool + // jsonMode picks which of the pre-derived inner handlers Handler.Handle + // dispatches to. Flipping it propagates instantly to every derived logger + // without rebuilding or chain-replaying anything. + jsonMode atomic.Bool +} + +// Handler is the slog.Handler returned by NewHandler. It holds two +// pre-derived slog handlers -- one text, one json -- both built from the +// same accumulated WithAttrs/WithGroup state. Handle picks which one to +// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes +// effect immediately across the whole process without having to rebuild +// any derived loggers. +type Handler struct { + root *handlerRoot + text slog.Handler + json slog.Handler +} + +func (h *Handler) Enabled(_ context.Context, l slog.Level) bool { + return h.root.level.Level() <= l +} + +func (h *Handler) Handle(ctx context.Context, r slog.Record) error { + if h.root.disableTimestamp.Load() { + r.Time = time.Time{} + } + if h.root.jsonMode.Load() { + return h.json.Handle(ctx, r) + } + return h.text.Handle(ctx, r) +} + +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithAttrs(attrs), + json: h.json.WithAttrs(attrs), + } +} + +func (h *Handler) WithGroup(name string) slog.Handler { + if name == "" { + return h + } + return &Handler{ + root: h.root, + text: h.text.WithGroup(name), + json: h.json.WithGroup(name), + } +} + +// SetLevel updates the effective log level. Propagates to every derived +// logger via the shared LevelVar. +func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) } + +// GetLevel reports the current log level. +func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() } + +// SetFormat flips the output format atomically. Valid formats are "text" +// and "json". Every derived logger sees the new format on its next Handle +// call; no rebuild or registration is required. +func (h *Handler) SetFormat(format string) error { + switch format { + case "text": + h.root.jsonMode.Store(false) + case "json": + h.root.jsonMode.Store(true) + default: + return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"}) + } + return nil +} + +// GetFormat reports the currently selected format name. +func (h *Handler) GetFormat() string { + if h.root.jsonMode.Load() { + return "json" + } + return "text" +} + +// SetDisableTimestamp toggles whether Handle zeroes r.Time before +// dispatching (slog's builtin text/json handlers skip emitting the time +// attribute on a zero time). +func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) } + +// ApplyConfig reads logging.level, logging.format, and (optionally) +// logging.disable_timestamp from c and applies them to l. The reconfig +// surface is discovered via structural type-assertion on l.Handler(), so +// foreign handlers silently opt out of whichever capabilities they do not +// implement. +// +// nebula.Main does NOT call this function on your behalf; callers that want +// config-driven log level / format / timestamp updates invoke it at +// startup and register it as a reload callback themselves. This keeps the +// library from mutating an embedder's logger without their say-so. +func ApplyConfig(l *slog.Logger, c Config) error { + h := l.Handler() + + lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) + if err != nil { + return err + } + if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok { + ls.SetLevel(lvl) + } + + format := strings.ToLower(c.GetString("logging.format", "text")) + if fs, ok := h.(interface{ SetFormat(string) error }); ok { + if err := fs.SetFormat(format); err != nil { + return err + } + } + + if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok { + ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false)) + } + return nil +} + +// ParseLevel converts a config-string level name ("trace", "debug", "info", +// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and +// "panic" are accepted for backwards compatibility with pre-slog configs +// and both map to slog.LevelError. +func ParseLevel(s string) (slog.Level, error) { + switch s { + case "trace": + return LevelTrace, nil + case "debug": + return slog.LevelDebug, nil + case "info": + return slog.LevelInfo, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + case "fatal", "panic": + return slog.LevelError, nil + default: + return 0, fmt.Errorf("not a valid logging level: %q", s) + } +} + +// LevelName returns a human-readable name for a slog.Level matching the +// strings accepted by ParseLevel. +func LevelName(l slog.Level) string { + switch { + case l <= LevelTrace: + return "trace" + case l <= slog.LevelDebug: + return "debug" + case l <= slog.LevelInfo: + return "info" + case l <= slog.LevelWarn: + return "warn" + default: + return "error" + } +} diff --git a/logging/logger_bench_test.go b/logging/logger_bench_test.go new file mode 100644 index 00000000..eb29c1c3 --- /dev/null +++ b/logging/logger_bench_test.go @@ -0,0 +1,90 @@ +package logging + +import ( + "context" + "io" + "log/slog" + "testing" +) + +// BenchmarkLogger_* compare the handler returned by NewLogger against a +// stock slog text handler. The key thing we care about is the per-log +// cost on a logger that has been derived via .With(), because that is the +// shape subsystems store on their structs (HostInfo.logger(), +// lh.l.With("subsystem", ...), etc.) and call from hot paths. + +func BenchmarkLogger_Stock_RootInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_RootInfo(b *testing.B) { + l := NewLogger(io.Discard) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + l.Info("hello", "i", i) + } +} + +// Gated-off-path benchmarks: mimic the typical hot-path shape +// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below +// the active level. This is the dominant pattern in inside.go/outside.go and +// what we pay on every packet. +func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) { + l := slog.New(slog.DiscardHandler).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} + +func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) { + l := NewLogger(io.Discard).With( + "subsystem", "bench", + "localIndex", 1234, + ) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if l.Enabled(ctx, slog.LevelDebug) { + l.Debug("hello", "i", i) + } + } +} diff --git a/main.go b/main.go index 0ac63dfa..f692f317 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,13 @@ package nebula import ( "context" "fmt" + "log/slog" "net" "net/netip" "runtime/debug" "strings" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/sshd" @@ -20,7 +20,7 @@ import ( type m = map[string]any -func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { +func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. defer func() { @@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg buildVersion = moduleVersion() } - l := logger - l.Formatter = &logrus.TextFormatter{ - FullTimestamp: true, - } - // Print the config if in test, the exit comes later if configTest { b, err := yaml.Marshal(c.Settings) @@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } // Print the final config - l.Println(string(b)) + l.Info(string(b)) } - err := configLogger(l, c) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) - } - - c.RegisterReloadCallback(func(c *config.C) { - err := configLogger(l, c) - if err != nil { - l.WithError(err).Error("Failed to configure the logger") - } - }) - pki, err := NewPKIFromConfig(l, c) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) @@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } - l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") + l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes()) - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) + ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) } @@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") + l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err) sshStart = nil } } @@ -99,7 +82,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg routines = 1 } if routines > 1 { - l.WithField("routines", routines).Info("Using multiple routines") + l.Info("Using multiple routines", "routines", routines) } } else { // deprecated and undocumented @@ -107,7 +90,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg udpQueues := c.GetInt("listen.routines", 1) routines = max(tunQueues, udpQueues) if routines != 1 { - l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead") + l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines) } } @@ -120,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg conntrackCacheTimeout = 1 * time.Second } if conntrackCacheTimeout > 0 { - l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") + l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout) } var tun overlay.Device @@ -166,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } for i := 0; i < routines; i++ { - l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port))) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) @@ -217,7 +200,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) if err != nil { - l.WithError(err).Warn("Failed to start DNS responder") + l.Warn("Failed to start DNS responder", "error", err) } ifConfig := &InterfaceConfig{ diff --git a/outside.go b/outside.go index eba9d887..1e00a0a9 100644 --- a/outside.go +++ b/outside.go @@ -1,15 +1,16 @@ package nebula import ( + "context" "encoding/binary" "errors" + "log/slog" "net/netip" "time" "github.com/google/gopacket/layers" "golang.org/x/net/ipv6" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "golang.org/x/net/ipv4" @@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) + f.l.Info("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) } return } @@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Refusing to process double encrypted packet", "from", via) } return } @@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } @@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) return } @@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) return } } @@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt lighthouse packet") + hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt test packet") + hostinfo.logger(f.l).Error("Failed to decrypt test packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -192,14 +212,15 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt CloseTunnel packet") + hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", + "error", err, + "from", via, + "packet", packet, + ) return } - hostinfo.logger(f.l).WithField("from", via). - Info("Close tunnel received, tearing down.") + hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) f.closeTunnel(hostinfo) return @@ -211,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt Control packet") + hostinfo.logger(f.l).Error("Failed to decrypt Control packet", + "error", err, + "from", via, + "packet", packet, + ) return } @@ -221,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) + } return } @@ -247,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") - return - } - - if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). - Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). - Info("Host roamed to new udp ip/port.") + if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote", + "suppressSeconds", RoamingSuppressSeconds, + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) + } + return + } + + hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.", + "udpAddr", hostinfo.remote, + "newAddr", via.UdpAddr, + ) hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote hostinfo.SetRemote(via.UdpAddr) @@ -491,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - hostinfo.logger(f.l).WithField("header", h). - Debugln("dropping out of window packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h) + } return nil, errors.New("out of window packet") } @@ -504,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") + hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) return false } err = newPacket(out, true, fwPacket) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). - Warnf("Error while validating inbound packet") + hostinfo.logger(f.l).Warn("Error while validating inbound packet", + "error", err, + "packet", out, + ) return false } if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - Debugln("dropping out of window packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) + } return false } @@ -526,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping inbound packet") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("dropping inbound packet", + "fwPacket", fwPacket, + "reason", dropReason, + ) } return false } @@ -537,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { - f.l.WithError(err).Error("Failed to write to tun") + f.l.Error("Failed to write to tun", "error", err) } return true } @@ -553,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) _ = f.outside.WriteTo(b, endpoint) - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", index). - WithField("udpAddr", endpoint). - Debug("Recv error sent") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error sent", + "index", index, + "udpAddr", endpoint, + ) } } func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if !f.acceptRecvErrorConfig.ShouldRecvError(addr) { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). - Debug("Recv error received, ignoring") + f.l.Debug("Recv error received, ignoring", + "index", h.RemoteIndex, + "udpAddr", addr, + ) return } - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("index", h.RemoteIndex). - WithField("udpAddr", addr). - Debug("Recv error received") + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Recv error received", + "index", h.RemoteIndex, + "udpAddr", addr, + ) } hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) if hostinfo == nil { - f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") + f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex) return } if hostinfo.remote.IsValid() && hostinfo.remote != addr { - f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) + f.l.Info("Someone spoofing recv_errors?", + "addr", addr, + "hostinfoRemote", hostinfo.remote, + ) return } diff --git a/test/tun.go b/overlay/overlaytest/noop.go similarity index 68% rename from test/tun.go rename to overlay/overlaytest/noop.go index fb32782f..956da7dd 100644 --- a/test/tun.go +++ b/overlay/overlaytest/noop.go @@ -1,4 +1,6 @@ -package test +// Package overlaytest provides fakes of overlay.Device for tests that do +// not want to touch a real tun device or route table. +package overlaytest import ( "errors" @@ -8,6 +10,9 @@ import ( "github.com/slackhq/nebula/routing" ) +// NoopTun is an overlay.Device that silently discards every read and write. +// Useful in tests that need to construct a nebula Interface but do not +// exercise the datapath. type NoopTun struct{} func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { diff --git a/overlay/route.go b/overlay/route.go index 61989581..c6403f91 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -2,6 +2,7 @@ package overlay import ( "fmt" + "log/slog" "math" "net" "net/netip" @@ -9,7 +10,6 @@ import ( "strconv" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) @@ -48,11 +48,14 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { +func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { - l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) + l.Warn("route MTU is not supported on this platform", + "goos", runtime.GOOS, + "route", r, + ) } gateways := r.Via diff --git a/overlay/route_test.go b/overlay/route_test.go index 9a959a55..f9d9dcd9 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 2) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") @@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) { routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 3) - routeTree, err := makeRouteTree(l, routes, true) + routeTree, err := makeRouteTree(test.NewLogger(), routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("192.168.86.1") diff --git a/overlay/tun.go b/overlay/tun.go index e0bf69f6..3af1e189 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -2,10 +2,10 @@ package overlay import ( "fmt" + "log/slog" "net" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) @@ -22,9 +22,9 @@ func (e *NameError) Error() string { } // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, vpnNetworks) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..9cbb64be 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,12 +6,12 @@ package overlay import ( "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -23,10 +23,10 @@ type tun struct { vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..524ef0cd 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" @@ -14,7 +15,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -30,7 +30,7 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger + l *slog.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -79,7 +79,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error { err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { - t.l.WithField("route", r.Cidr). - Warnf("unable to add unsafe_route, identical route already exists") + t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr) } else { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { @@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error { } } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..f47880dd 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -1,13 +1,14 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "strings" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/routing" ) @@ -19,10 +20,10 @@ type disabledTun struct { // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter rx metrics.Counter - l *logrus.Logger + l *slog.Logger } -func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), @@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) { } t.tx.Inc(1) - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Write payload", "raw", prettyPacket(r)) } return copy(b, r), nil @@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { select { case t.read <- out: default: - t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") + t.l.Debug("tun_disabled: dropped ICMP Echo Reply response") } return true @@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) { // Check for ICMP Echo Request before spending time doing the full parsing if t.handleICMPEchoRequest(b) { - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b)) } - } else if t.l.Level >= logrus.DebugLevel { - t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") + } else if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b)) } return len(b), nil } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 91c51159..3d995553 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/fs" + "log/slog" "net/netip" "os" "sync/atomic" @@ -17,8 +18,9 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -93,7 +95,7 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr - l *logrus.Logger + l *slog.Logger fd int shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls @@ -243,7 +245,7 @@ func (t *tun) Close() error { if t.fd >= 0 { if err := unix.Close(t.fd); err != nil { - t.l.WithError(err).Error("Error closing device") + t.l.Error("Error closing device", "error", err) } t.fd = -1 } @@ -264,7 +266,7 @@ func (t *tun) Close() error { err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) } if err != nil { - t.l.WithError(err).Error("Error destroying tunnel") + t.l.Error("Error destroying tunnel", "error", err) } }() @@ -277,11 +279,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error @@ -584,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -599,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.linkAddr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..6bfcbdfb 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "sync" @@ -14,7 +15,6 @@ import ( "syscall" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -25,14 +25,14 @@ type tun struct { vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 2830ff6b..c6cfb686 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "io" + "log/slog" "net" "net/netip" "os" @@ -17,7 +18,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -213,7 +213,7 @@ type tun struct { routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystemLock sync.Mutex - l *logrus.Logger + l *slog.Logger } func (t *tun) Networks() []netip.Prefix { @@ -238,7 +238,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) if err != nil { return nil, err @@ -249,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -299,7 +299,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. -func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { tfd, err := newTunFd(fd) if err != nil { _ = unix.Close(fd) @@ -378,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error { if !initial { if oldMaxMTU != newMaxMTU { t.setMTU() - t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) + t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU) } if oldDefaultMTU != newDefaultMTU { for i := range t.vpnNetworks { err := t.setDefaultRoute(t.vpnNetworks[i]) if err != nil { - t.l.Warn(err) + t.l.Warn(err.Error()) } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU) } } } @@ -492,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error { } err = netlink.AddrDel(link, &al[i]) if err != nil { - t.l.WithError(err).Error("failed to remove address from tun address list") + t.l.Error("failed to remove address from tun address list", "error", err) } else { - t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + t.l.Info("removed address not listed in cert(s)", "removed", al[i].String()) } } @@ -538,12 +538,12 @@ func (t *tun) Activate() error { ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { // If we can't set the queue length nebula will still work but it may lead to packet loss - t.l.WithError(err).Error("Failed to set tun tx queue length") + t.l.Error("Failed to set tun tx queue length", "error", err) } const modeNone = 1 if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil { - t.l.WithError(err).Warn("Failed to disable link local address generation") + t.l.Warn("Failed to disable link local address generation", "error", err) } if err = t.addIPs(link); err != nil { @@ -582,7 +582,7 @@ func (t *tun) setMTU() { ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { // This is currently a non fatal condition because the route table must have the MTU set appropriately as well - t.l.WithError(err).Error("Failed to set tun mtu") + t.l.Error("Failed to set tun mtu", "error", err) } } @@ -605,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { } err := netlink.RouteReplace(&nr) if err != nil { - t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` for i := 0; i < 2; i++ { time.Sleep(100 * time.Millisecond) @@ -613,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { if err == nil { break } else { - t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + t.l.Warn("Failed to set default route MTU, retrying", + "error", err, + "cidr", cidr, + "mtu", t.DefaultMTU, + ) } } if err != nil { @@ -658,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -690,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) { err := netlink.RouteDel(&nr) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } } @@ -721,11 +725,11 @@ func (t *tun) watchRoutes() { netlinkOptions := netlink.RouteSubscribeOptions{ ReceiveBufferSize: t.useSystemRoutesBufferSize, ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0, - ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") }, + ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) }, } if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil { - t.l.WithError(err).Errorf("failed to subscribe to system route changes") + t.l.Error("failed to subscribe to system route changes", "error", err) return } @@ -767,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { link, err := netlink.LinkByName(t.Device) if err != nil { - t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") + t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device) return gateways } @@ -779,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { gateways = append(gateways, routing.NewGateway(gwAddr, 1)) } else { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } @@ -795,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) } else { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") + t.l.Debug("Ignoring route update, gateway is not in our network", "route", r) } } else { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") + t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r) } } } @@ -830,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { gateways := t.getGatewaysFromRoute(&r.Route) if len(gateways) == 0 { // No gateways relevant to our network, no routing changes required. - t.l.WithField("route", r).Debug("Ignoring route update, no gateways") + t.l.Debug("Ignoring route update, no gateways", "route", r) return } if r.Dst == nil { - t.l.WithField("route", r).Debug("Ignoring route update, no destination address") + t.l.Debug("Ignoring route update, no destination address", "route", r) return } dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + t.l.Debug("Ignoring route update, invalid destination address", "route", r) return } @@ -852,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { t.routesFromSystemLock.Lock() if r.Type == unix.RTM_NEWROUTE { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + t.l.Info("Adding route", "destination", dst, "via", gateways) t.routesFromSystem[dst] = gateways newTree.Insert(dst, gateways) } else { - t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") + t.l.Info("Removing route", "destination", dst, "via", gateways) delete(t.routesFromSystem, dst) newTree.Delete(dst) } @@ -888,18 +892,18 @@ func (t *tun) Close() error { } err := t.readers[i].Close() if err != nil { - t.l.WithField("reader", i).WithError(err).Error("error closing tun reader") + t.l.Error("error closing tun reader", "reader", i, "error", err) } else { - t.l.WithField("reader", i).Info("closed tun reader") + t.l.Info("closed tun reader", "reader", i) } } //this is t.readers[0] too err := t.tunFile.Close() if err != nil { - t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader") + t.l.Error("error closing tun reader", "reader", 0, "error", err) } else { - t.l.WithField("reader", 0).Info("closed tun reader") + t.l.Info("closed tun reader", "reader", 0) } return err } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..c971bb6e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -63,18 +63,18 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = unix.SetNonblock(fd, true) if err != nil { - l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := &tun{ @@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..81362184 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/netip" "os" "regexp" @@ -15,7 +16,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -54,7 +54,7 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger f *os.File fd int // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -63,11 +63,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( err = unix.SetNonblock(fd, true) if err != nil { - l.WithError(err).Warn("Failed to set the tun device as nonblocking") + l.Warn("Failed to set the tun device as nonblocking", "error", err) } t := &tun{ @@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } } @@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error { err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..b2c2a0ea 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -4,14 +4,15 @@ package overlay import ( + "context" "fmt" "io" + "log/slog" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) @@ -21,14 +22,14 @@ type TestTun struct { vpnNetworks []netip.Prefix Routes []Route routeTree *bart.Table[routing.Gateways] - l *logrus.Logger + l *slog.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err @@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) { return } - if t.l.Level >= logrus.DebugLevel { - t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Tun receiving injected packet", "dataLen", len(packet)) } t.rxPackets <- packet } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 223eabee..680dddb3 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -7,6 +7,7 @@ import ( "crypto" "fmt" "io" + "log/slog" "net/netip" "os" "path/filepath" @@ -16,7 +17,6 @@ import ( "unsafe" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" @@ -33,16 +33,16 @@ type winTun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + l *slog.Logger tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) @@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( if err != nil { // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. // Trying a second time resolves the issue. - l.WithError(err).Debug("Failed to create wintun device, retrying") + l.Debug("Failed to create wintun device, retrying", "error", err) tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { return nil, &NameError{ @@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error { return retErr } } else { - t.l.WithField("route", r).Info("Added route") + t.l.Info("Added route", "route", r) } if !foundDefault4 { @@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error { // See comment on luid.AddRoute err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + t.l.Error("Failed to remove route", "error", err, "route", r) } else { - t.l.WithField("route", r).Info("Removed route") + t.l.Info("Removed route", "route", r) } } return nil diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..e5f27f37 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,14 +2,14 @@ package overlay import ( "io" + "log/slog" "net/netip" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } diff --git a/pki.go b/pki.go index 0639fd3d..fb8cc5c6 100644 --- a/pki.go +++ b/pki.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/netip" "os" @@ -15,7 +16,6 @@ import ( "time" "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" @@ -24,7 +24,7 @@ import ( type PKI struct { cs atomic.Pointer[CertState] caPool atomic.Pointer[cert.CAPool] - l *logrus.Logger + l *slog.Logger } type CertState struct { @@ -46,7 +46,7 @@ type CertState struct { myVpnBroadcastAddrsTable *bart.Lite } -func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { +func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) { pki := &PKI{l: l} err := pki.reload(c, true) if err != nil { @@ -182,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { p.cs.Store(newState) if initial { - p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") + p.l.Debug("Client nebula certificate(s)", "cert", newState) } else { - p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") + p.l.Info("Client certificate(s) refreshed from disk", "cert", newState) } return nil } @@ -196,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { } p.caPool.Store(caPool) - p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints()) return nil } @@ -487,7 +487,7 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) { return c, b, nil } -func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { +func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) { caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") @@ -512,7 +512,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { for _, crt := range caPool.CAs { if crt.Certificate.Expired(time.Now()) { expired++ - l.WithField("cert", crt).Warn("expired certificate present in CA pool") + l.Warn("expired certificate present in CA pool", "cert", crt) } } @@ -530,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { caPool.BlocklistFingerprint(fp) } - l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates") + l.Info("Blocklisted certificates", "fingerprintCount", len(bl)) } return caPool, nil diff --git a/pki_hup_benchmark_test.go b/pki_hup_benchmark_test.go index 39f648ff..bca23d78 100644 --- a/pki_hup_benchmark_test.go +++ b/pki_hup_benchmark_test.go @@ -41,7 +41,7 @@ func BenchmarkReloadConfigWithCAs(b *testing.B) { c := config.NewC(l) require.NoError(b, c.Load(dir)) - _, err := NewPKIFromConfig(l, c) + _, err := NewPKIFromConfig(test.NewLogger(), c) require.NoError(b, err) b.ReportAllocs() diff --git a/punchy.go b/punchy.go index 2034405a..6ecf4f85 100644 --- a/punchy.go +++ b/punchy.go @@ -1,10 +1,10 @@ package nebula import ( + "log/slog" "sync/atomic" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) @@ -14,10 +14,10 @@ type Punchy struct { delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool - l *logrus.Logger + l *slog.Logger } -func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { +func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { p := &Punchy{l: l} p.reload(c, true) @@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { p.respond.Store(yes) if !initial { - p.l.Infof("punchy.respond changed to %v", p.GetRespond()) + p.l.Info("punchy.respond changed", "respond", p.GetRespond()) } } @@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) { if initial || c.HasChanged("punchy.delay") { p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) if !initial { - p.l.Infof("punchy.delay changed to %s", p.GetDelay()) + p.l.Info("punchy.delay changed", "delay", p.GetDelay()) } } if initial || c.HasChanged("punchy.target_all_remotes") { p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) if !initial { - p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") + p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) } } if initial || c.HasChanged("punchy.respond_delay") { p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) if !initial { - p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay()) + p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) } } } diff --git a/punchy_test.go b/punchy_test.go index 56dd1c25..cbf9b17b 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -1,6 +1,8 @@ package nebula import ( + "context" + "log/slog" "testing" "time" @@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.False(t, p.GetPunch()) assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) @@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) { // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.True(t, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} - p = NewPunchyFromConfig(l, c) + p = NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } @@ -62,7 +64,7 @@ punchy: delay: 1m respond: false `)) - p := NewPunchyFromConfig(l, c) + p := NewPunchyFromConfig(test.NewLogger(), c) assert.Equal(t, delay, p.GetDelay()) assert.False(t, p.GetRespond()) @@ -76,3 +78,158 @@ punchy: assert.Equal(t, newDelay, p.GetDelay()) assert.True(t, p.GetRespond()) } + +// The tests below pin the shape of each log line Punchy produces so changes +// cannot silently break whatever operators are grepping for. The assertions +// are on the structured message + attrs (e.g. "punchy.respond changed" with +// a respond=true field) rather than a formatted string. +// +// Punchy.reload also emits a spurious "Changing punchy.punch with reload is +// not supported" warning whenever any key under punchy changes, because of +// the c.HasChanged("punchy") fallback kept for the deprecated top-level +// punchy form. The tests filter by message rather than asserting total +// entry counts so that warning is tolerated without being locked into +// the format. + +type capturedEntry struct { + Level slog.Level + Msg string + Attrs map[string]any +} + +// capturingHandler is a slog.Handler that records each Record it receives so +// tests can assert on the level, message, and attribute map of individual log +// lines without coupling to any specific text format. +type capturingHandler struct { + entries []capturedEntry +} + +func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } + +func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error { + e := capturedEntry{ + Level: r.Level, + Msg: r.Message, + Attrs: make(map[string]any), + } + r.Attrs(func(a slog.Attr) bool { + e.Attrs[a.Key] = a.Value.Resolve().Any() + return true + }) + h.entries = append(h.entries, e) + return nil +} + +func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h } + +func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) { + t.Helper() + hook := &capturingHandler{} + return slog.New(hook), hook +} + +func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry { + t.Helper() + for _, e := range entries { + if e.Msg == msg { + return e + } + } + t.Fatalf("no entry with message %q among %d entries", msg, len(entries)) + return capturedEntry{} +} + +func TestPunchy_LogFormat_InitialEnabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: true}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, "punchy enabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_InitialDisabled(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: false}`)) + + NewPunchyFromConfig(l, c) + + entry := findEntry(t, hook.entries, "punchy disabled") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {punch: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) + + entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") + assert.Equal(t, slog.LevelWarn, entry.Level) + assert.Empty(t, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) + + entry := findEntry(t, hook.entries, "punchy.respond changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) + + entry := findEntry(t, hook.entries, "punchy.delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"delay": 10 * time.Second}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) + + entry := findEntry(t, hook.entries, "punchy.target_all_remotes changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"target_all_remotes": true}, entry.Attrs) +} + +func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) { + l, hook := newCapturingPunchyLogger(t) + c := config.NewC(test.NewLogger()) + require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) + NewPunchyFromConfig(l, c) + hook.entries = nil + + require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`)) + + entry := findEntry(t, hook.entries, "punchy.respond_delay changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, entry.Attrs) +} diff --git a/relay_manager.go b/relay_manager.go index 91640f24..919bb2b6 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -5,22 +5,22 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net/netip" "sync/atomic" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) type relayManager struct { - l *logrus.Logger + l *slog.Logger hostmap *HostMap amRelay atomic.Bool } -func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager { +func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager { rm := &relayManager{ l: l, hostmap: hostmap, @@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c c.RegisterReloadCallback(func(c *config.C) { err := rm.reload(c, false) if err != nil { - l.WithError(err).Error("Failed to reload relay_manager") + rm.l.Error("Failed to reload relay_manager", "error", err) } }) return rm @@ -52,7 +52,7 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for range 32 { @@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - fields := logrus.Fields{ - "relay": relayHostInfo.vpnAddrs[0], - "initiatorRelayIndex": m.InitiatorRelayIndex, - } - + var relayFrom, relayTo any if m.RelayFromAddr == nil { - fields["relayFrom"] = m.OldRelayFromAddr + relayFrom = m.OldRelayFromAddr } else { - fields["relayFrom"] = m.RelayFromAddr + relayFrom = m.RelayFromAddr } - if m.RelayToAddr == nil { - fields["relayTo"] = m.OldRelayToAddr + relayTo = m.OldRelayToAddr } else { - fields["relayTo"] = m.RelayToAddr + relayTo = m.RelayToAddr } - rm.l.WithFields(fields).Info("relayManager failed to update relay") + rm.l.Info("relayManager failed to update relay", + "relay", relayHostInfo.vpnAddrs[0], + "initiatorRelayIndex", m.InitiatorRelayIndex, + "relayFrom", relayFrom, + "relayTo", relayTo, + ) return nil, fmt.Errorf("unknown relay") } @@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { msg := &NebulaControl{} err := msg.Unmarshal(d) if err != nil { - h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + h.logger(f.l).Error("Failed to unmarshal control message", "error", err) return } @@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { } func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { - rm.l.WithFields(logrus.Fields{ - "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), - "relayTo": protoAddrToNetAddr(m.RelayToAddr), - "initiatorRelayIndex": m.InitiatorRelayIndex, - "responderRelayIndex": m.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). - Info("handleCreateRelayResponse") + rm.l.Info("handleCreateRelayResponse", + "relayFrom", protoAddrToNetAddr(m.RelayFromAddr), + "relayTo", protoAddrToNetAddr(m.RelayToAddr), + "initiatorRelayIndex", m.InitiatorRelayIndex, + "responderRelayIndex", m.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) target := m.RelayToAddr targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { - rm.l.WithError(err).Error("Failed to update relay for relayTo") + rm.l.Error("Failed to update relay for relayTo", "error", err) return } // Do I need to complete the relays now? @@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f // I'm the middle man. Let the initiator know that the I've established the relay they requested. peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") + rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr) return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") + rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0]) return } switch peerRelay.State { @@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f if v == cert.Version1 { peer := peerHostInfo.vpnAddrs[0] if !peer.Is4() { - rm.l.WithField("relayFrom", peer). - WithField("relayTo", target). - WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). - WithField("responderRelayIndex", resp.ResponderRelayIndex). - WithField("vpnAddrs", peerHostInfo.vpnAddrs). - Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address", + "relayFrom", peer, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) return } @@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - rm.l.WithError(err). - Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromAddr, - "relayTo": resp.RelayToAddr, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": peerHostInfo.vpnAddrs}). - Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", resp.RelayFromAddr, + "relayTo", resp.RelayToAddr, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", peerHostInfo.vpnAddrs, + ) } } } @@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f from := protoAddrToNetAddr(m.RelayFromAddr) target := protoAddrToNetAddr(m.RelayToAddr) - logMsg := rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnAddrs": h.vpnAddrs}) + logMsg := rm.l.With( + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", m.InitiatorRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. if f.myVpnAddrsTable.Contains(from) { - logMsg.WithField("myIP", from).Error("Discarding relay request from myself") + logMsg.Error("Discarding relay request from myself", "myIP", from) return } @@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } case Disestablished: if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + logMsg.Error("Existing relay mismatch with CreateRelayRequest", + "existingRemoteIndex", existingRelay.RemoteIndex) return } // Mark the relay as 'Established' because it's safe to use again h.relayState.UpdateRelayForByIpState(from, Established) case PeerRequested: // I should never be in this state, because I am terminal, not forwarding. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": existingRelay.RemoteIndex, - "state": existingRelay.State}).Error("Unexpected Relay State found") + logMsg.Error("Unexpected Relay State found", + "existingRemoteIndex", existingRelay.RemoteIndex, + "state", existingRelay.State) } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) if err != nil { - logMsg.WithError(err).Error("Failed to add relay") + logMsg.Error("Failed to add relay", "error", err) return } } relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.WithField("from", from).Error("Relay State not found") + logMsg.Error("Relay State not found", "from", from) return } @@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := resp.Marshal() if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": from, - "relayTo": target, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnAddrs": h.vpnAddrs}). - Info("send CreateRelayResponse") + rm.l.Info("send CreateRelayResponse", + "relayFrom", from, + "relayTo", target, + "initiatorRelayIndex", resp.InitiatorRelayIndex, + "responderRelayIndex", resp.ResponderRelayIndex, + "vpnAddrs", h.vpnAddrs, + ) } return } else { @@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if v == cert.Version1 { if !h.vpnAddrs[0].Is4() { - rm.l.WithField("relayFrom", h.vpnAddrs[0]). - WithField("relayTo", target). - WithField("initiatorRelayIndex", req.InitiatorRelayIndex). - WithField("responderRelayIndex", req.ResponderRelayIndex). - WithField("vpnAddr", target). - Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") + rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) return } @@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f msg, err := req.Marshal() if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control message to create relay") + logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": h.vpnAddrs[0], - "relayTo": target, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnAddr": target}). - Info("send CreateRelayRequest") + rm.l.Info("send CreateRelayRequest", + "relayFrom", h.vpnAddrs[0], + "relayTo", target, + "initiatorRelayIndex", req.InitiatorRelayIndex, + "responderRelayIndex", req.ResponderRelayIndex, + "vpnAddr", target, + ) } // Also track the half-created Relay state just received @@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if !ok { _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to allocate a local index for relay") + logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err) return } } diff --git a/remote_list.go b/remote_list.go index 8338d517..7b95de87 100644 --- a/remote_list.go +++ b/remote_list.go @@ -2,6 +2,7 @@ package nebula import ( "context" + "log/slog" "net" "net/netip" "slices" @@ -10,8 +11,6 @@ import ( "sync" "sync/atomic" "time" - - "github.com/sirupsen/logrus" ) // forEachFunc is used to benefit folks that want to do work inside the lock @@ -66,11 +65,11 @@ type hostnamesResults struct { network string lookupTimeout time.Duration cancelFn func() - l *logrus.Logger + l *slog.Logger ips atomic.Pointer[map[netip.AddrPort]struct{}] } -func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { +func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { r := &hostnamesResults{ hostnames: make([]hostnamePort, len(hostPorts)), network: network, @@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) timeoutCancel() if err != nil { - l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") + l.Error("DNS resolution failed for static_map host", + "hostname", hostPort.name, + "network", r.network, + "error", err, + ) continue } for _, a := range addrs { @@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, } } if different { - l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + l.Info("DNS results changed for host list", + "origSet", origSet, + "newSet", netipAddrs, + ) r.ips.Store(&netipAddrs) onUpdate() } diff --git a/service/service_test.go b/service/service_test.go index c6b87423..4bcc8437 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,11 +10,11 @@ import ( "time" "dario.cat/mergo" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/overlay" "go.yaml.in/yaml/v3" "golang.org/x/sync/errgroup" @@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n panic(err) } - logger := logrus.New() - logger.Out = os.Stdout + logger := logging.NewLogger(os.Stdout) control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { diff --git a/ssh.go b/ssh.go index b2912d55..3863b5ec 100644 --- a/ssh.go +++ b/ssh.go @@ -6,21 +6,21 @@ import ( "errors" "flag" "fmt" + "log/slog" "maps" "net" "net/netip" "os" "path/filepath" - "reflect" "runtime" "runtime/pprof" "sort" "strconv" "strings" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/sshd" ) @@ -57,12 +57,12 @@ type sshDeviceInfoFlags struct { Pretty bool } -func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { +func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) { c.RegisterReloadCallback(func(c *config.C) { if c.GetBool("sshd.enabled", false) { sshRun, err := configSSH(l, ssh, c) if err != nil { - l.WithError(err).Error("Failed to reconfigure the sshd") + l.Error("Failed to reconfigure the sshd", "error", err) ssh.Stop() } if sshRun != nil { @@ -78,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { // updates the passed-in SSHServer. On success, it returns a function // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. -func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { +func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { listen := c.GetString("sshd.listen", "") if listen == "" { return nil, fmt.Errorf("sshd.listen must be provided") @@ -120,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, caAuthorizedKey := range rawCAs { err := ssh.AddTrustedCA(caAuthorizedKey) if err != nil { - l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring") + l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey) continue } } @@ -131,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, rk := range keys { kDef, ok := rk.(map[string]any) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") + l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk) continue } user, ok := kDef["user"].(string) if !ok { - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field") + l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk) continue } @@ -146,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro case string: err := ssh.AddAuthorizedKey(user, v) if err != nil { - l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", rk, + "sshKey", v, + ) continue } @@ -154,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro for _, subK := range v { sk, ok := subK.(string) if !ok { - l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key") + l.Warn("Did not understand ssh key", + "sshKeyConfig", rk, + "sshKey", subK, + ) continue } err := ssh.AddAuthorizedKey(user, sk) if err != nil { - l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key") + l.Warn("Failed to authorize key", + "error", err, + "sshKeyConfig", sk, + ) continue } } default: - l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood") + l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk) } } } else { @@ -178,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro ssh.Stop() runner = func() { if err := ssh.Run(listen); err != nil { - l.WithField("err", err).Warn("Failed to run the SSH server") + l.Warn("Failed to run the SSH server", "error", err) } } } else { @@ -188,7 +198,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return runner, nil } -func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { +func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { // sandboxDir defaults to a dir in temp. The intention is that end user will // create this dir as needed. Overriding this config value to "" allows // writing to anywhere in the system. @@ -789,36 +799,45 @@ func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetLevel() slog.Level + SetLevel(slog.Level) + }) + if !ok { + return w.WriteLine("Log level is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } - level, err := logrus.ParseLevel(a[0]) + level, err := logging.ParseLevel(strings.ToLower(a[0])) if err != nil { - return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels)) + return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a)) } - l.SetLevel(level) - return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + ctrl.SetLevel(level) + return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel()))) } -func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error { + ctrl, ok := l.Handler().(interface { + GetFormat() string + SetFormat(string) error + }) + if !ok { + return w.WriteLine("Log format is not reconfigurable on this logger") + } + if len(a) == 0 { - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } - logFormat := strings.ToLower(a[0]) - switch logFormat { - case "text": - l.Formatter = &logrus.TextFormatter{} - case "json": - l.Formatter = &logrus.JSONFormatter{} - default: - return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) + if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil { + return err } - - return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat())) } func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { diff --git a/sshd/server.go b/sshd/server.go index 4b5cc3e0..38886e53 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -5,16 +5,16 @@ import ( "context" "errors" "fmt" + "log/slog" "net" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) type SSHServer struct { config *ssh.ServerConfig - l *logrus.Entry + l *slog.Logger certChecker *ssh.CertChecker @@ -33,7 +33,7 @@ type SSHServer struct { } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen -func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { +func NewSSHServer(l *slog.Logger) (*SSHServer, error) { ctx, cancel := context.WithCancel(context.Background()) s := &SSHServer{ @@ -121,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error { } s.trustedCAs = append(s.trustedCAs, pk) - s.l.WithField("sshKey", pubKey).Info("Trusted CA key") + s.l.Info("Trusted CA key", "sshKey", pubKey) return nil } @@ -139,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { } tk[string(pk.Marshal())] = true - s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") + s.l.Info("Authorized ssh key", + "sshKey", pubKey, + "sshUser", user, + ) return nil } @@ -156,7 +159,7 @@ func (s *SSHServer) Run(addr string) error { return err } - s.l.WithField("sshListener", addr).Info("SSH server is listening") + s.l.Info("SSH server is listening", "sshListener", addr) // Run loops until there is an error s.run() @@ -172,7 +175,7 @@ func (s *SSHServer) run() { c, err := s.listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { - s.l.WithError(err).Warn("Error in listener, shutting down") + s.l.Warn("Error in listener, shutting down", "error", err) } return } @@ -193,23 +196,29 @@ func (s *SSHServer) run() { } if err != nil { - l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + l := s.l.With( + "error", err, + "remoteAddress", c.RemoteAddr(), + ) if conn != nil { - l = l.WithField("sshUser", conn.User()) + l = l.With("sshUser", conn.User()) conn.Close() } if fp != "" { - l = l.WithField("sshFingerprint", fp) + l = l.With("sshFingerprint", fp) } l.Warn("failed to handshake") sessionCancel() return } - l := s.l.WithField("sshUser", conn.User()) - l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") + l := s.l.With("sshUser", conn.User()) + l.Info("ssh user logged in", + "remoteAddress", c.RemoteAddr(), + "sshFingerprint", fp, + ) - NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session")) + NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session")) go ssh.DiscardRequests(reqs) @@ -221,7 +230,7 @@ func (s *SSHServer) Stop() { // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { - s.l.WithError(err).Warn("Failed to close the sshd listener") + s.l.Warn("Failed to close the sshd listener", "error", err) } } } diff --git a/sshd/session.go b/sshd/session.go index 39c81bd0..1c8e1a9b 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -2,25 +2,25 @@ package sshd import ( "fmt" + "log/slog" "sort" "strings" "github.com/anmitsu/go-shlex" "github.com/armon/go-radix" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "golang.org/x/term" ) type session struct { - l *logrus.Entry + l *slog.Logger c *ssh.ServerConn term *term.Terminal commands *radix.Tree cancel func() } -func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session { +func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session { s := &session{ commands: radix.NewFromMap(commands.ToMap()), l: l, @@ -45,14 +45,14 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) { defer s.Close() for newChannel := range chans { if newChannel.ChannelType() != "session" { - s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") + s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType()) newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { - s.l.WithError(err).Warn("could not accept channel") + s.l.Warn("could not accept channel", "error", err) continue } @@ -95,12 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { return default: - s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request") + s.l.Debug("Rejected unknown request", "sshRequest", req.Type) err = req.Reply(false, nil) } if err != nil { - s.l.WithError(err).Info("Error handling ssh session requests") + s.l.Info("Error handling ssh session requests", "error", err) return } } diff --git a/stats.go b/stats.go index c88c45cc..c7bf3a06 100644 --- a/stats.go +++ b/stats.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log" + "log/slog" "net" "net/http" "runtime" @@ -15,14 +16,13 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) // startStats initializes stats from config. On success, if any further work // is needed to serve stats, it returns a func to handle that work. If no // work is needed, it'll return nil. On failure, it returns nil, error. -func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { +func startStats(l *slog.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { mType := c.GetString("stats.type", "") if mType == "" || mType == "none" { return nil, nil @@ -59,7 +59,7 @@ func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest b return startFn, nil } -func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error { +func startGraphiteStats(l *slog.Logger, i time.Duration, c *config.C, configTest bool) error { proto := c.GetString("stats.protocol", "tcp") host := c.GetString("stats.host", "") if host == "" { @@ -73,13 +73,17 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTe } if !configTest { - l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) + l.Info("Starting graphite", + "interval", i, + "prefix", prefix, + "addr", addr.String(), + ) go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) } return nil } -func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { +func startPrometheusStats(l *slog.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { namespace := c.GetString("stats.namespace", "") subsystem := c.GetString("stats.subsystem", "") @@ -116,9 +120,15 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV var startFn func() if !configTest { + // promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger, + // so bridge our slog.Logger back to a *log.Logger that emits at Error. + errLog := slog.NewLogLogger(l.Handler(), slog.LevelError) startFn = func() { - l.Infof("Prometheus stats listening on %s at %s", listen, path) - http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) + l.Info("Prometheus stats listening", + "listen", listen, + "path", path, + ) + http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog})) log.Fatal(http.ListenAndServe(listen, nil)) } } diff --git a/test/logger.go b/test/logger.go index b5a717d8..faab0b69 100644 --- a/test/logger.go +++ b/test/logger.go @@ -1,29 +1,73 @@ package test import ( + "context" "io" + "log/slog" "os" + "time" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/logging" ) -func NewLogger() *logrus.Logger { - l := logrus.New() - +// NewLogger returns a *slog.Logger suitable for use in tests. Output goes to +// io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to +// stream output to stderr for local debugging. +func NewLogger() *slog.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(io.Discard) - return l + return slog.New(slog.DiscardHandler) } + level := slog.LevelInfo switch v { case "2": - l.SetLevel(logrus.DebugLevel) + level = slog.LevelDebug case "3": - l.SetLevel(logrus.TraceLevel) - default: - l.SetLevel(logrus.InfoLevel) + level = logging.LevelTrace } - - return l + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) +} + +// NewLoggerWithOutput returns a *slog.Logger whose text output is captured by +// w. Timestamps are suppressed so tests can assert on exact output without +// baking the current time into expected strings. +func NewLoggerWithOutput(w io.Writer) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, nil)}) +} + +// NewLoggerWithOutputAndLevel is NewLoggerWithOutput with an explicit level +// so tests can exercise Enabled-gated paths. +func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// NewJSONLoggerWithOutput returns a *slog.Logger emitting JSON to w with +// timestamps suppressed, for tests that pin the JSON shape. +func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger { + return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, &slog.HandlerOptions{Level: level})}) +} + +// stripTimeHandler zeros each record's time before delegating so slog's +// built-in handlers skip emitting the time attribute. Used to avoid +// timestamp-dependent assertions in tests without resorting to ReplaceAttr. +type stripTimeHandler struct { + inner slog.Handler +} + +func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool { + return h.inner.Enabled(ctx, l) +} + +func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error { + r.Time = time.Time{} + return h.inner.Handle(ctx, r) +} + +func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)} +} + +func (h *stripTimeHandler) WithGroup(name string) slog.Handler { + return &stripTimeHandler{inner: h.inner.WithGroup(name)} } diff --git a/udp/udp_android.go b/udp/udp_android.go index bb191954..3fc68003 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -9,11 +9,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 65ef31a5..c42a3c18 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -12,11 +12,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 863c98f3..8a4f5b18 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,12 +8,12 @@ import ( "encoding/binary" "errors" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) @@ -22,12 +22,12 @@ type StdConn struct { *net.UDPConn isV4 bool sysFd uintptr - l *logrus.Logger + l *slog.Logger } var _ Conn = &StdConn{} -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -176,7 +176,7 @@ func (u *StdConn) ListenOut(r EncReader) error { return err } - u.l.WithError(err).Error("unexpected udp socket receive error") + u.l.Error("unexpected udp socket receive error", "error", err) } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) @@ -196,7 +196,7 @@ func (u *StdConn) Rebind() error { } if err != nil { - u.l.WithError(err).Error("Failed to rebind udp socket") + u.l.Error("Failed to rebind udp socket", "error", err) } return nil diff --git a/udp/udp_generic.go b/udp/udp_generic.go index ad26f794..131eb73b 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -12,22 +12,22 @@ import ( "context" "errors" "fmt" + "log/slog" "net" "net/netip" "time" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type GenericConn struct { *net.UDPConn - l *logrus.Logger + l *slog.Logger } var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -88,7 +88,7 @@ func (u *GenericConn) ListenOut(r EncReader) error { // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 21a34147..3e2d726a 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -7,13 +7,13 @@ import ( "context" "encoding/binary" "fmt" + "log/slog" "net" "net/netip" "syscall" "unsafe" "github.com/rcrowley/go-metrics" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) @@ -22,7 +22,7 @@ type StdConn struct { udpConn *net.UDPConn rawConn syscall.RawConn isV4 bool - l *logrus.Logger + l *slog.Logger batch int } @@ -38,7 +38,7 @@ func setReusePort(network, address string, c syscall.RawConn) error { return opErr } -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { listen := netip.AddrPortFrom(ip, uint16(port)) lc := net.ListenConfig{} if multi { @@ -242,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetRecvBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.read_buffer was set") + u.l.Info("listen.read_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.read_buffer") + u.l.Warn("Failed to get listen.read_buffer", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.read_buffer") + u.l.Error("Failed to set listen.read_buffer", "error", err) } } @@ -257,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSendBuffer() if err == nil { - u.l.WithField("size", s).Info("listen.write_buffer was set") + u.l.Info("listen.write_buffer was set", "size", s) } else { - u.l.WithError(err).Warn("Failed to get listen.write_buffer") + u.l.Warn("Failed to get listen.write_buffer", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.write_buffer") + u.l.Error("Failed to set listen.write_buffer", "error", err) } } @@ -273,12 +273,12 @@ func (u *StdConn) ReloadConfig(c *config.C) { if err == nil { s, err := u.GetSoMark() if err == nil { - u.l.WithField("mark", s).Info("listen.so_mark was set") + u.l.Info("listen.so_mark was set", "mark", s) } else { - u.l.WithError(err).Warn("Failed to get listen.so_mark") + u.l.Warn("Failed to get listen.so_mark", "error", err) } } else { - u.l.WithError(err).Error("Failed to set listen.so_mark") + u.l.Error("Failed to set listen.so_mark", "error", err) } } } diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 3b69159a..4b2de75a 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -11,11 +11,12 @@ import ( "net/netip" "syscall" - "github.com/sirupsen/logrus" + "log/slog" + "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 607b978e..d110af19 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/netip" "sync" @@ -17,7 +18,6 @@ import ( "time" "unsafe" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" @@ -53,14 +53,14 @@ type ringBuffer struct { type RIOConn struct { isOpen atomic.Bool - l *logrus.Logger + l *slog.Logger sock windows.Handle rx, tx ringBuffer rq winrio.Rq results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { +func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } @@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro return u, nil } -func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { +func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error { var err error u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) if err != nil { @@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. - l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl") + l.Debug("failed to set UDP_CONNRESET ioctl", "error", err) } ret = 0 @@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. - l.WithError(err).Debug("failed to set UDP_NETRESET ioctl") + l.Debug("failed to set UDP_NETRESET ioctl", "error", err) } err = u.rx.Open() @@ -156,7 +156,7 @@ func (u *RIOConn) ListenOut(r EncReader) error { // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() - u.l.WithError(err).Warn("unexpected udp socket receive error") + u.l.Warn("unexpected udp socket receive error", "error", err) } continue } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 388b17d0..fcd0967c 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,12 +4,13 @@ package udp import ( + "context" "io" + "log/slog" "net/netip" "os" "sync" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -46,10 +47,10 @@ type TesterConn struct { done chan struct{} closeOnce sync.Once - l *logrus.Logger + l *slog.Logger } -func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), @@ -67,11 +68,12 @@ func (u *TesterConn) Send(packet *Packet) { if err := h.Parse(packet.Data); err != nil { panic(err) } - if u.l.Level >= logrus.DebugLevel { - u.l.WithField("header", h). - WithField("udpAddr", packet.From). - WithField("dataLen", len(packet.Data)). - Debug("UDP receiving injected packet") + if u.l.Enabled(context.Background(), slog.LevelDebug) { + u.l.Debug("UDP receiving injected packet", + "header", h, + "udpAddr", packet.From, + "dataLen", len(packet.Data), + ) } select { case <-u.done: diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 1b777c37..7969f7e8 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -5,14 +5,13 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - - "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between @@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return rc, nil } - l.WithError(err).Error("Falling back to standard udp sockets") + l.Error("Falling back to standard udp sockets", "error", err) return NewGenericListener(l, ip, port, multi, batch) } diff --git a/util/error.go b/util/error.go index 814c77a1..14371d3f 100644 --- a/util/error.go +++ b/util/error.go @@ -1,10 +1,10 @@ package util import ( + "context" "errors" "fmt" - - "github.com/sirupsen/logrus" + "log/slog" ) type ContextualError struct { @@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error { } // LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError -func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { +func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) { switch v := err.(type) { case *ContextualError: v.Log(l) default: - l.WithError(err).Error(msg) + l.Error(msg, "error", err) } } @@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error { return ce.RealError } -func (ce *ContextualError) Log(lr *logrus.Logger) { - if ce.RealError != nil { - lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) - } else { - lr.WithFields(ce.Fields).Error(ce.Context) +// Log emits ce as a single error-level log line with Fields and RealError +// promoted to top-level attributes, producing a flat shape callers can grep +// or parse without walking into a nested object. +func (ce *ContextualError) Log(l *slog.Logger) { + attrs := make([]slog.Attr, 0, len(ce.Fields)+1) + for k, v := range ce.Fields { + attrs = append(attrs, slog.Any(k, v)) } + if ce.RealError != nil { + attrs = append(attrs, slog.Any("error", ce.RealError)) + } + // LogAttrs is intentional: attrs is built from a map[string]any so it has + // no pair-form equivalent. + //nolint:sloglint + l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...) } diff --git a/util/error_test.go b/util/error_test.go index 692c1840..30e39e33 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -1,95 +1,67 @@ package util import ( + "bytes" "errors" "fmt" "testing" - "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) type m = map[string]any -type TestLogWriter struct { - Logs []string -} - -func NewTestLogWriter() *TestLogWriter { - return &TestLogWriter{Logs: make([]string, 0)} -} - -func (tl *TestLogWriter) Write(p []byte) (n int, err error) { - tl.Logs = append(tl.Logs, string(p)) - return len(p), nil -} - -func (tl *TestLogWriter) Reset() { - tl.Logs = tl.Logs[:0] -} - func TestContextualError_Log(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test a full context line - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test a line with an error and msg but no fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" error=error\n", buf.String()) // Test just a context and fields - tl.Reset() + buf.Reset() e = NewContextualError("test message", m{"field": "1"}, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String()) // Test just a context - tl.Reset() + buf.Reset() e = NewContextualError("test message", nil, nil) e.Log(l) - assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String()) // Test just an error - tl.Reset() + buf.Reset() e = NewContextualError("", nil, errors.New("error")) e.Log(l) - assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String()) } func TestLogWithContextIfNeeded(t *testing.T) { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{ - DisableTimestamp: true, - DisableColors: true, - } - - tl := NewTestLogWriter() - l.Out = tl + buf := &bytes.Buffer{} + l := test.NewLoggerWithOutput(buf) // Test ignoring fallback context - tl.Reset() + buf.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) LogWithContextIfNeeded("This should get thrown away", e, l) - assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String()) // Test using fallback context - tl.Reset() + buf.Reset() err := fmt.Errorf("this is a normal error") LogWithContextIfNeeded("Fallback context woo", err, l) - assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) + assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String()) } func TestContextualizeIfNeeded(t *testing.T) {