Switch to slog, remove logrus (#1672)

This commit is contained in:
Nate Brown
2026-04-27 09:41:47 -05:00
committed by GitHub
parent 5f890dbc34
commit d0f02ba873
77 changed files with 2299 additions and 1338 deletions

View File

@@ -2,7 +2,21 @@ version: "2"
linters: linters:
default: none default: none
enable: enable:
- sloglint
- testifylint - testifylint
settings:
sloglint:
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
# custom levels instead of LogAttrs when you can.
#
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
# the few legitimate sites (where attrs is built up as a []slog.Attr)
# carry a //nolint:sloglint with rationale.
kv-only: true
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
# discard-handler is on by default (since Go 1.24): suggests
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
exclusions: exclusions:
generated: lax generated: lax
presets: presets:

38
bits.go
View File

@@ -1,8 +1,10 @@
package nebula package nebula
import ( import (
"context"
"log/slog"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
) )
type Bits struct { type Bits struct {
@@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits {
return b return b
} }
func (b *Bits) Check(l *logrus.Logger, i uint64) bool { func (b *Bits) Check(l *slog.Logger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current { if i > b.current {
return true return true
@@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
} }
// Not within the window // Not within the window
if l.Level >= logrus.DebugLevel { if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i) l.Debug("rejected a packet (top)",
"current", b.current,
"incoming", i,
)
} }
return false return false
} }
func (b *Bits) Update(l *logrus.Logger, i uint64) bool { func (b *Bits) Update(l *slog.Logger, i uint64) bool {
// If i is the next number, return true and update current. // If i is the next number, return true and update current.
if i == b.current+1 { if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
@@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// Check to see if it's a duplicate // Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length { if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true { if b.current == i || b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel { if l.Enabled(context.Background(), slog.LevelDebug) {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). l.Debug("Receive window",
Debug("Receive window") "accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "duplicate",
)
} }
b.dupeCounter.Inc(1) b.dupeCounter.Inc(1)
return false return false
@@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// In all other cases, fail and don't change current. // In all other cases, fail and don't change current.
b.outOfWindowCounter.Inc(1) b.outOfWindowCounter.Inc(1)
if l.Level >= logrus.DebugLevel { if l.Enabled(context.Background(), slog.LevelDebug) {
l.WithField("accepted", false). l.Debug("Receive window",
WithField("currentCounter", b.current). "accepted", false,
WithField("incomingCounter", i). "currentCounter", b.current,
WithField("reason", "nonsense"). "incomingCounter", i,
Debug("Receive window") "reason", "nonsense",
)
} }
return false return false
} }

View File

@@ -3,8 +3,15 @@
package main package main
import "github.com/sirupsen/logrus" import (
"log/slog"
"os"
func HookLogger(l *logrus.Logger) { "github.com/slackhq/nebula/logging"
// Do nothing, let the logs flow to stdout/stderr )
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
// platforms have no special sink to integrate with.
func newPlatformLogger() *slog.Logger {
return logging.NewLogger(os.Stdout)
} }

View File

@@ -1,54 +1,86 @@
package main package main
import ( import (
"fmt" "context"
"io/ioutil" "log/slog"
"os" "strings"
"sync"
"github.com/kardianos/service" "github.com/slackhq/nebula/logging"
"github.com/sirupsen/logrus"
) )
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer // newPlatformLogger returns a *slog.Logger that routes every log record
// logrus output will be discarded // through the Windows service logger so records end up in the Windows
func HookLogger(l *logrus.Logger) { // Event Log. All the heavy lifting (level management, format swap,
l.AddHook(newLogHook(logger)) // timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
l.SetOutput(ioutil.Discard) // this file only contributes:
//
// - an io.Writer that forwards each formatted line to the service
// logger at the current record's Event Log severity, and
// - a thin severityTag that embeds *logging.Handler and overrides
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
// column and severity-based filters keep working the way they did
// before the slog migration.
//
// Format (text vs json) is carried by the embedded *logging.Handler, so
// logging.format: json in config still produces JSON lines in Event
// Viewer, same as the pre-slog logrus setup.
func newPlatformLogger() *slog.Logger {
w := &eventLogWriter{}
return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w})
} }
type logHook struct { // eventLogWriter forwards slog-formatted lines to the Windows service
sl service.Logger // logger at the severity most recently stashed by severityTag.Handle.
// The mutex serializes the stash + inner.Handle + Write cycle per record
// across all concurrent goroutines; slog's builtin text/json handlers
// each hold their own mutex around Write, but that only protects the
// Write call itself, not our stash-then-handle sequence.
type eventLogWriter struct {
mu sync.Mutex
level slog.Level
} }
func newLogHook(sl service.Logger) *logHook { func (w *eventLogWriter) Write(p []byte) (int, error) {
return &logHook{sl: sl} line := strings.TrimRight(string(p), "\n")
} switch {
case w.level >= slog.LevelError:
func (h *logHook) Fire(entry *logrus.Entry) error { return len(p), logger.Error(line)
line, err := entry.String() case w.level >= slog.LevelWarn:
if err != nil { return len(p), logger.Warning(line)
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
return err
}
switch entry.Level {
case logrus.PanicLevel:
return h.sl.Error(line)
case logrus.FatalLevel:
return h.sl.Error(line)
case logrus.ErrorLevel:
return h.sl.Error(line)
case logrus.WarnLevel:
return h.sl.Warning(line)
case logrus.InfoLevel:
return h.sl.Info(line)
case logrus.DebugLevel:
return h.sl.Info(line)
default: default:
return nil return len(p), logger.Info(line)
} }
} }
func (h *logHook) Levels() []logrus.Level { // severityTag embeds *logging.Handler to pick up everything it does for
return logrus.AllLevels // free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
// so each record's slog.Level is stashed on the writer before formatting
// and so derived handlers stay wrapped as severityTag rather than
// downgrading to bare *logging.Handler.
type severityTag struct {
*logging.Handler
w *eventLogWriter
}
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
s.w.mu.Lock()
defer s.w.mu.Unlock()
s.w.level = r.Level
return s.Handler.Handle(ctx, r)
}
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
if len(attrs) == 0 {
return s
}
return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w}
}
func (s *severityTag) WithGroup(name string) slog.Handler {
if name == "" {
return s
}
return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w}
} }

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug" "runtime/debug"
"strings" "strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -50,12 +50,11 @@ func main() {
os.Exit(0) os.Exit(0)
} }
l := logrus.New() l := logging.NewLogger(os.Stdout)
l.Out = os.Stdout
if *serviceFlag != "" { if *serviceFlag != "" {
if err := doService(configPath, configTest, Build, serviceFlag); err != nil { if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
l.WithError(err).Error("Service command failed") l.Error("Service command failed", "error", err)
os.Exit(1) os.Exit(1)
} }
return return
@@ -74,6 +73,16 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil) ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil { if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l) util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -90,7 +99,7 @@ func main() {
go ctrl.ShutdownBlock() go ctrl.ShutdownBlock()
if err := wait(); err != nil { if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error") l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2) os.Exit(2)
} }

View File

@@ -7,9 +7,9 @@ import (
"path/filepath" "path/filepath"
"github.com/kardianos/service" "github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
) )
var logger service.Logger var logger service.Logger
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
// Start should not block. // Start should not block.
logger.Info("Nebula service starting.") logger.Info("Nebula service starting.")
l := logrus.New() l := newPlatformLogger()
HookLogger(l)
c := config.NewC(l) c := config.NewC(l)
err := c.Load(*p.configPath) err := c.Load(*p.configPath)
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
return fmt.Errorf("failed to load config: %s", err) return fmt.Errorf("failed to load config: %s", err)
} }
if err := logging.ApplyConfig(l, c); err != nil {
return fmt.Errorf("failed to apply logging config: %s", err)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil) p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
if err != nil { if err != nil {
return err return err
@@ -85,7 +93,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
// Here are what the different loggers are doing: // Here are what the different loggers are doing:
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr // - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log) // - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use // - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
s, err := service.New(prg, svcConfig) s, err := service.New(prg, svcConfig)
if err != nil { if err != nil {
return err return err

View File

@@ -7,9 +7,9 @@ import (
"runtime/debug" "runtime/debug"
"strings" "strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -55,8 +55,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
l := logrus.New() l := logging.NewLogger(os.Stdout)
l.Out = os.Stdout
c := config.NewC(l) c := config.NewC(l)
err := c.Load(*configPath) err := c.Load(*configPath)
@@ -65,6 +64,16 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if err := logging.ApplyConfig(l, c); err != nil {
fmt.Printf("failed to apply logging config: %s", err)
os.Exit(1)
}
c.RegisterReloadCallback(func(c *config.C) {
if err := logging.ApplyConfig(l, c); err != nil {
l.Error("Failed to reconfigure logger on reload", "error", err)
}
})
ctrl, err := nebula.Main(c, *configTest, Build, l, nil) ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
if err != nil { if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l) util.LogWithContextIfNeeded("Failed to start", err, l)
@@ -82,7 +91,7 @@ func main() {
notifyReady(l) notifyReady(l)
if err := wait(); err != nil { if err := wait(); err != nil {
l.WithError(err).Error("Nebula stopped due to fatal error") l.Error("Nebula stopped due to fatal error", "error", err)
os.Exit(2) os.Exit(2)
} }

View File

@@ -1,11 +1,10 @@
package main package main
import ( import (
"log/slog"
"net" "net"
"os" "os"
"time" "time"
"github.com/sirupsen/logrus"
) )
// SdNotifyReady tells systemd the service is ready and dependent services can now be started // SdNotifyReady tells systemd the service is ready and dependent services can now be started
@@ -13,30 +12,30 @@ import (
// https://www.freedesktop.org/software/systemd/man/systemd.service.html // https://www.freedesktop.org/software/systemd/man/systemd.service.html
const SdNotifyReady = "READY=1" const SdNotifyReady = "READY=1"
func notifyReady(l *logrus.Logger) { func notifyReady(l *slog.Logger) {
sockName := os.Getenv("NOTIFY_SOCKET") sockName := os.Getenv("NOTIFY_SOCKET")
if sockName == "" { if sockName == "" {
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
return return
} }
conn, err := net.DialTimeout("unixgram", sockName, time.Second) conn, err := net.DialTimeout("unixgram", sockName, time.Second)
if err != nil { if err != nil {
l.WithError(err).Error("failed to connect to systemd notification socket") l.Error("failed to connect to systemd notification socket", "error", err)
return return
} }
defer conn.Close() defer conn.Close()
err = conn.SetWriteDeadline(time.Now().Add(time.Second)) err = conn.SetWriteDeadline(time.Now().Add(time.Second))
if err != nil { if err != nil {
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket") l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
return return
} }
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
l.WithError(err).Error("failed to signal the systemd notification socket") l.Error("failed to signal the systemd notification socket", "error", err)
return return
} }
l.Debugln("notified systemd the service is ready") l.Debug("notified systemd the service is ready")
} }

View File

@@ -3,8 +3,8 @@
package main package main
import "github.com/sirupsen/logrus" import "log/slog"
func notifyReady(_ *logrus.Logger) { func notifyReady(_ *slog.Logger) {
// No init service to notify // No init service to notify
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"math" "math"
"os" "os"
"os/signal" "os/signal"
@@ -16,7 +17,6 @@ import (
"time" "time"
"dario.cat/mergo" "dario.cat/mergo"
"github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3" "go.yaml.in/yaml/v3"
) )
@@ -26,11 +26,11 @@ type C struct {
Settings map[string]any Settings map[string]any
oldSettings map[string]any oldSettings map[string]any
callbacks []func(*C) callbacks []func(*C)
l *logrus.Logger l *slog.Logger
reloadLock sync.Mutex reloadLock sync.Mutex
} }
func NewC(l *logrus.Logger) *C { func NewC(l *slog.Logger) *C {
return &C{ return &C{
Settings: make(map[string]any), Settings: make(map[string]any),
l: l, l: l,
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
newVals, err := yaml.Marshal(nv) newVals, err := yaml.Marshal(nv)
if err != nil { if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") c.l.Error("Error while marshaling new config",
"config_path", k,
"error", err,
)
} }
oldVals, err := yaml.Marshal(ov) oldVals, err := yaml.Marshal(ov)
if err != nil { if err != nil {
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") c.l.Error("Error while marshaling old config",
"config_path", k,
"error", err,
)
} }
return string(newVals) != string(oldVals) return string(newVals) != string(oldVals)
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
err := c.Load(c.path) err := c.Load(c.path)
if err != nil { if err != nil {
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") c.l.Error("Error occurred while reloading config",
"config_path", c.path,
"error", err,
)
return return
} }

View File

@@ -5,13 +5,13 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"log/slog"
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@@ -47,10 +47,10 @@ type connectionManager struct {
metricsTxPunchy metrics.Counter metricsTxPunchy metrics.Counter
l *logrus.Logger l *slog.Logger
} }
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
cm := &connectionManager{ cm := &connectionManager{
hostMap: hm, hostMap: hm,
l: l, l: l,
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.getInactivityTimeout() old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial { if !initial {
cm.l.WithField("oldDuration", old). cm.l.Info("Inactivity timeout has changed",
WithField("newDuration", cm.getInactivityTimeout()). "oldDuration", old,
Info("Inactivity timeout has changed") "newDuration", cm.getInactivityTimeout(),
)
} }
} }
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
old := cm.dropInactive.Load() old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial { if !initial {
cm.l.WithField("oldBool", old). cm.l.Info("Drop inactive setting has changed",
WithField("newBool", cm.dropInactive.Load()). "oldBool", old,
Info("Drop inactive setting has changed") "newBool", cm.dropInactive.Load(),
)
} }
} }
} }
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
var err error var err error
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil { if err != nil {
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
continue continue
} }
switch r.Type { switch r.Type {
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
} else { } else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
cm.l.WithFields(logrus.Fields{ cm.l.Info("send CreateRelayRequest",
"relayFrom": req.RelayFromAddr, "relayFrom", req.RelayFromAddr,
"relayTo": req.RelayToAddr, "relayTo", req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex, "responderRelayIndex", req.ResponderRelayIndex,
"vpnAddrs": newhostinfo.vpnAddrs}). "vpnAddrs", newhostinfo.vpnAddrs,
Info("send CreateRelayRequest") )
} }
} }
} }
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
hostinfo := cm.hostMap.Indexes[localIndex] hostinfo := cm.hostMap.Indexes[localIndex]
if hostinfo == nil { if hostinfo == nil {
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
return doNothing, nil, nil return doNothing, nil, nil
} }
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
// A hostinfo is determined alive if there is incoming traffic // A hostinfo is determined alive if there is incoming traffic
if inTraffic { if inTraffic {
decision := doNothing decision := doNothing
if cm.l.Level >= logrus.DebugLevel { if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l). hostinfo.logger(cm.l).Debug("Tunnel status",
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). "tunnelCheck", m{"state": "alive", "method": "passive"},
Debug("Tunnel status") )
} }
hostinfo.pendingDeletion.Store(false) hostinfo.pendingDeletion.Store(false)
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if hostinfo.pendingDeletion.Load() { if hostinfo.pendingDeletion.Load() {
// We have already sent a test packet and nothing was returned, this hostinfo is dead // We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(cm.l). hostinfo.logger(cm.l).Info("Tunnel status",
WithField("tunnelCheck", m{"state": "dead", "method": "active"}). "tunnelCheck", m{"state": "dead", "method": "active"},
Info("Tunnel status") )
return deleteTunnel, hostinfo, nil return deleteTunnel, hostinfo, nil
} }
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
inactiveFor, isInactive := cm.isInactive(hostinfo, now) inactiveFor, isInactive := cm.isInactive(hostinfo, now)
if isInactive { if isInactive {
// Tunnel is inactive, tear it down // Tunnel is inactive, tear it down
hostinfo.logger(cm.l). hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
WithField("inactiveDuration", inactiveFor). "inactiveDuration", inactiveFor,
WithField("primary", mainHostInfo). "primary", mainHostInfo,
Info("Dropping tunnel due to inactivity") )
return closeTunnel, hostinfo, primary return closeTunnel, hostinfo, primary
} }
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
cm.sendPunch(hostinfo) cm.sendPunch(hostinfo)
} }
if cm.l.Level >= logrus.DebugLevel { if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l). hostinfo.logger(cm.l).Debug("Tunnel status",
WithField("tunnelCheck", m{"state": "testing", "method": "active"}). "tunnelCheck", m{"state": "testing", "method": "active"},
Debug("Tunnel status") )
} }
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
decision = sendTestPacket decision = sendTestPacket
} else { } else {
if cm.l.Level >= logrus.DebugLevel { if cm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(cm.l).Debugf("Hostinfo sadness") hostinfo.logger(cm.l).Debug("Hostinfo sadness")
} }
} }
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
return false //cert is still valid! yay! return false //cert is still valid! yay!
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
// Block listed certificates should always be disconnected // Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err). hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
WithField("fingerprint", remoteCert.Fingerprint). "error", err,
Info("Remote certificate is blocked, tearing down the tunnel") "fingerprint", remoteCert.Fingerprint,
)
return true return true
} else if cm.intf.disconnectInvalid.Load() { } else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err). hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
WithField("fingerprint", remoteCert.Fingerprint). "error", err,
Info("Remote certificate is no longer valid, tearing down the tunnel") "fingerprint", remoteCert.Fingerprint,
)
return true return true
} else { } else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
curCrtVersion := curCrt.Version() curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion) myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil { if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). cm.l.Info("Re-handshaking with remote",
WithField("version", curCrtVersion). "vpnAddrs", hostinfo.vpnAddrs,
WithField("reason", "local certificate removed"). "version", curCrtVersion,
Info("Re-handshaking with remote") "reason", "local certificate removed",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return return
} }
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake? // if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil { if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). cm.l.Info("Re-handshaking with remote",
WithField("version", curCrtVersion). "vpnAddrs", hostinfo.vpnAddrs,
WithField("peerVersion", peerCrt.Certificate.Version()). "version", curCrtVersion,
WithField("reason", "local certificate version lower than peer, attempting to correct"). "peerVersion", peerCrt.Certificate.Version(),
Info("Re-handshaking with remote") "reason", "local certificate version lower than peer, attempting to correct",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version() hh.initiatingVersionOverride = peerCrt.Certificate.Version()
}) })
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
} }
} }
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). cm.l.Info("Re-handshaking with remote",
WithField("reason", "local certificate is not current"). "vpnAddrs", hostinfo.vpnAddrs,
Info("Re-handshaking with remote") "reason", "local certificate is not current",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return return
} }
if curCrtVersion < cs.initiatingVersion { if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). cm.l.Info("Re-handshaking with remote",
WithField("reason", "current cert version < pki.initiatingVersion"). "vpnAddrs", hostinfo.vpnAddrs,
Info("Re-handshaking with remote") "reason", "current cert version < pki.initiatingVersion",
)
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return return

View File

@@ -10,6 +10,7 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
@@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs) ifce.pki.cs.Store(cs)
// Create manager // Create manager
conf := config.NewC(l) conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(l, conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
@@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs) ifce.pki.cs.Store(cs)
// Create manager // Create manager
conf := config.NewC(l) conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(l, conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
@@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
ifce.pki.cs.Store(cs) ifce.pki.cs.Store(cs)
// Create manager // Create manager
conf := config.NewC(l) conf := config.NewC(test.NewLogger())
conf.Settings["tunnels"] = map[string]any{ conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true, "drop_inactive": true,
} }
punchy := NewPunchyFromConfig(l, conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load()) assert.True(t, nc.dropInactive.Load())
nc.intf = ifce nc.intf = ifce
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{}, outside: &udp.NoopConn{},
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
@@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true) ifce.disconnectInvalid.Store(true)
// Create manager // Create manager
conf := config.NewC(l) conf := config.NewC(test.NewLogger())
punchy := NewPunchyFromConfig(l, conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
ifce.connectionManager = nc ifce.connectionManager = nc

View File

@@ -8,7 +8,6 @@ import (
"sync/atomic" "sync/atomic"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
) )
@@ -27,7 +26,7 @@ type ConnectionState struct {
writeLock sync.Mutex writeLock sync.Mutex
} }
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc var dhFunc noise.DHFunc
switch crt.Curve() { switch crt.Curve() {
case cert.Curve_CURVE25519: case cert.Curve_CURVE25519:

View File

@@ -3,13 +3,13 @@ package nebula
import ( import (
"context" "context"
"errors" "errors"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
"syscall" "syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
@@ -46,7 +46,7 @@ type Control struct {
state RunState state RunState
f *Interface f *Interface
l *logrus.Logger l *slog.Logger
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
sshStart func() sshStart func()
@@ -151,7 +151,7 @@ func (c *Control) Stop() {
c.CloseAllTunnels(false) c.CloseAllTunnels(false)
if err := c.f.Close(); err != nil { if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed") c.l.Error("Close interface failed", "error", err)
} }
c.stateLock.Lock() c.stateLock.Lock()
c.state = StateStopped c.state = StateStopped
@@ -166,7 +166,7 @@ func (c *Control) ShutdownBlock() {
rawSig := <-sigChan rawSig := <-sigChan
sig := rawSig.String() sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down") c.l.Info("Caught signal, shutting down", "signal", sig)
c.Stop() c.Stop()
} }
@@ -303,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h) c.f.closeTunnel(h)
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). c.l.Debug("Sending close tunnel message",
Debug("Sending close tunnel message") "vpnAddrs", h.vpnAddrs,
"udpAddr", h.remote,
)
closed++ closed++
} }

View File

@@ -6,7 +6,6 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -83,7 +82,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
f: &Interface{ f: &Interface{
hostMap: hm, hostMap: hm,
}, },
l: logrus.New(), l: test.NewLogger(),
} }
thi := c.GetHostInfoByVpnAddr(vpnIp, false) thi := c.GetHostInfoByVpnAddr(vpnIp, false)

View File

@@ -3,6 +3,7 @@ package nebula
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"strconv" "strconv"
@@ -12,13 +13,12 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
type dnsServer struct { type dnsServer struct {
sync.RWMutex sync.RWMutex
l *logrus.Logger l *slog.Logger
ctx context.Context ctx context.Context
dnsMap4 map[string]netip.Addr dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr dnsMap6 map[string]netip.Addr
@@ -55,7 +55,7 @@ type dnsServer struct {
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel // they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
// watcher that tears the listener down on nebula shutdown. The returned // watcher that tears the listener down on nebula shutdown. The returned
// pointer is always non-nil, even on error. // pointer is always non-nil, even on error.
func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
ds := &dnsServer{ ds := &dnsServer{
l: l, l: l,
ctx: ctx, ctx: ctx,
@@ -69,7 +69,7 @@ func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
if err := ds.reload(c, false); err != nil { if err := ds.reload(c, false); err != nil {
l.WithError(err).Error("Failed to reload DNS responder from config") ds.l.Error("Failed to reload DNS responder from config", "error", err)
} }
}) })
@@ -145,7 +145,7 @@ func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reaso
<-started <-started
} }
if err := srv.Shutdown(); err != nil { if err := srv.Shutdown(); err != nil {
d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder") d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
} }
} }
@@ -188,7 +188,7 @@ func (d *dnsServer) Start() {
} }
}() }()
d.l.WithField("dnsListener", addr).Info("Starting DNS responder") d.l.Info("Starting DNS responder", "dnsListener", addr)
err := server.ListenAndServe() err := server.ListenAndServe()
close(done) close(done)
@@ -201,7 +201,7 @@ func (d *dnsServer) Start() {
} }
if err != nil { if err != nil {
d.l.WithError(err).Warn("Failed to run the DNS responder") d.l.Warn("Failed to run the DNS responder", "error", err)
} }
} }
@@ -314,6 +314,7 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
} }
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
// Per RFC 2308 §2.2, a name that exists but has no record of the requested // Per RFC 2308 §2.2, a name that exists but has no record of the requested
// type must be answered with NOERROR and an empty answer section (NODATA), // type must be answered with NOERROR and an empty answer section (NODATA),
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not // not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
@@ -323,7 +324,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
switch q.Qtype { switch q.Qtype {
case dns.TypeA, dns.TypeAAAA: case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype] qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name) if debugEnabled {
d.l.Debug("DNS query", "type", qType, "name", q.Name)
}
ip, nameExists := d.Query(q.Qtype, q.Name) ip, nameExists := d.Query(q.Qtype, q.Name)
if nameExists { if nameExists {
anyNameExists = true anyNameExists = true
@@ -339,7 +342,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return return
} }
d.l.Debugf("Query for TXT %s", q.Name) if debugEnabled {
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
}
ip := d.QueryCert(q.Name) ip := d.QueryCert(q.Name)
if ip != "" { if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))

View File

@@ -2,7 +2,7 @@ package nebula
import ( import (
"context" "context"
"io" "log/slog"
"net" "net"
"net/netip" "net/netip"
"strconv" "strconv"
@@ -10,7 +10,6 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -30,7 +29,7 @@ func (stubDNSWriter) TsigTimersOnly(bool) {}
func (stubDNSWriter) Hijack() {} func (stubDNSWriter) Hijack() {}
func TestParsequery(t *testing.T) { func TestParsequery(t *testing.T) {
l := logrus.New() l := slog.New(slog.DiscardHandler)
hostMap := &HostMap{} hostMap := &HostMap{}
ds := &dnsServer{ ds := &dnsServer{
l: l, l: l,
@@ -137,10 +136,9 @@ func Test_getDnsServerAddr(t *testing.T) {
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
t.Helper() t.Helper()
l := logrus.New() sl := slog.New(slog.DiscardHandler)
l.Out = io.Discard
ds := &dnsServer{ ds := &dnsServer{
l: l, l: sl,
ctx: context.Background(), ctx: context.Background(),
dnsMap4: make(map[string]netip.Addr), dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr),
@@ -148,7 +146,7 @@ func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
} }
ds.mux = dns.NewServeMux() ds.mux = dns.NewServeMux()
ds.mux.HandleFunc(".", ds.handleDnsRequest) ds.mux.HandleFunc(".", ds.handleDnsRequest)
return ds, config.NewC(l) return ds, config.NewC(nil)
} }
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) { func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {

View File

@@ -11,7 +11,6 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
@@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) {
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
@@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Get a tunnel between me and relay") r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay") r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me") r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
r.Log("Wait for a packet from them to me") r.Log("Wait for a packet from them to me; myControl")
l.Info("Wait for a packet from them to me; myControl")
r.RouteForAllUntilTxTun(myControl) r.RouteForAllUntilTxTun(myControl)
l.Info("Wait for a packet from them to me; theirControl") r.Log("Wait for a packet from them to me; theirControl")
r.RouteForAllUntilTxTun(theirControl) r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels") t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels") t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
l.WithFields( len(myControl.GetHostmap().Indexes),
logrus.Fields{ len(theirControl.GetHostmap().Indexes),
"myControl": len(myControl.GetHostmap().Indexes), len(relayControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes), )
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
retries := 60 retries := 60
for hostInfos > 6 && retries > 0 { for hostInfos > 6 && retries > 0 {
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes) hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
l.WithFields( t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
logrus.Fields{ len(myControl.GetHostmap().Indexes),
"myControl": len(myControl.GetHostmap().Indexes), len(theirControl.GetHostmap().Indexes),
"theirControl": len(theirControl.GetHostmap().Indexes), len(relayControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes), )
}).Info("Waiting for hostinfos to be removed...")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) {
} }
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
myControl.Stop() myControl.Stop()

View File

@@ -4,7 +4,6 @@
package e2e package e2e
import ( import (
"fmt"
"io" "io"
"net/netip" "net/netip"
"os" "os"
@@ -12,15 +11,18 @@ import (
"testing" "testing"
"time" "time"
"log/slog"
"dario.cat/mergo" "dario.cat/mergo"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/logging"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "go.yaml.in/yaml/v3"
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": udpAddr.Port(), "port": udpAddr.Port(),
}, },
"logging": m{ "logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), "level": testLogLevelName(),
"level": l.Level.String(),
}, },
"timers": m{ "timers": m{
"pending_deletion_interval": 2, "pending_deletion_interval": 2,
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
"port": udpAddr.Port(), "port": udpAddr.Port(),
}, },
"logging": m{ "logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), "level": testLogLevelName(),
"level": l.Level.String(),
}, },
"timers": m{ "timers": m{
"pending_deletion_interval": 2, "pending_deletion_interval": 2,
@@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
return a return a
} }
func NewTestLogger() *logrus.Logger { func NewTestLogger() *slog.Logger {
l := logrus.New()
v := os.Getenv("TEST_LOGS") v := os.Getenv("TEST_LOGS")
if v == "" { if v == "" {
l.SetOutput(io.Discard) return slog.New(slog.NewTextHandler(io.Discard, nil))
l.SetLevel(logrus.PanicLevel)
return l
} }
level := slog.LevelInfo
switch v { switch v {
case "2": case "2":
l.SetLevel(logrus.DebugLevel) level = slog.LevelDebug
case "3": case "3":
l.SetLevel(logrus.TraceLevel) level = logging.LevelTrace
default: }
l.SetLevel(logrus.InfoLevel) return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
} }
return l // testLogLevelName returns the level name string accepted by logging.ApplyConfig
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
func testLogLevelName() string {
switch os.Getenv("TEST_LOGS") {
case "2":
return "debug"
case "3":
return "trace"
case "":
return "info"
}
return "info"
} }

View File

@@ -292,23 +292,17 @@ tun:
# Configure logging level # Configure logging level
logging: logging:
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable. # trace, debug, info, warn, or error. Default is info and is reloadable.
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some # fatal and panic are accepted for backwards compatibility and map to error.
# scenarios. Debug logging is also CPU intensive and will decrease performance overall. #NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
# Only enable debug logging while actively investigating an issue. # scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
# Only enable debug or trace logging while actively investigating an issue.
level: info level: info
# json or text formats currently available. Default is text # json or text formats currently available. Default is text.
format: text format: text
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false # Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
#disable_timestamp: true #disable_timestamp: true
# timestamp format is specified in Go time format, see: # Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
# https://golang.org/pkg/time/#pkg-constants
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
# default when `format: text`:
# when TTY attached: seconds since beginning of execution
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
# As an example, to log as RFC3339 with millisecond precision, set to:
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
#stats: #stats:
#type: graphite #type: graphite

View File

@@ -7,9 +7,9 @@ import (
"net" "net"
"os" "os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/service" "github.com/slackhq/nebula/service"
) )
@@ -64,8 +64,7 @@ pki:
return err return err
} }
logger := logrus.New() logger := logging.NewLogger(os.Stdout)
logger.Out = os.Stdout
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil { if err != nil {

View File

@@ -1,11 +1,13 @@
package nebula package nebula
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"log/slog"
"net/netip" "net/netip"
"reflect" "reflect"
"slices" "slices"
@@ -16,7 +18,6 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -67,7 +68,7 @@ type Firewall struct {
incomingMetrics firewallMetrics incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics outgoingMetrics firewallMetrics
l *logrus.Logger l *slog.Logger
} }
type firewallMetrics struct { type firewallMetrics struct {
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory. // The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration //TODO: error on 0 duration
var tmin, tmax time.Duration var tmin, tmax time.Duration
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
} }
} }
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2) certificate := cs.getCertificate(cert.Version2)
if certificate == nil { if certificate == nil {
certificate = cs.getCertificate(cert.Version1) certificate = cs.getCertificate(cert.Version1)
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop": case "drop":
fw.InSendReject = false fw.InSendReject = false
default: default:
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`") l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
fw.InSendReject = false fw.InSendReject = false
} }
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
case "drop": case "drop":
fw.OutSendReject = false fw.OutSendReject = false
default: default:
l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
fw.OutSendReject = false fw.OutSendReject = false
} }
@@ -268,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
case firewall.ProtoICMP, firewall.ProtoICMPv6: case firewall.ProtoICMP, firewall.ProtoICMPv6:
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
if startPort != firewall.PortAny { if startPort != firewall.PortAny {
f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule") f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
} }
startPort = firewall.PortAny startPort = firewall.PortAny
endPort = firewall.PortAny endPort = firewall.PortAny
@@ -290,8 +291,9 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). f.l.Info("Firewall rule added",
Info("Firewall rule added") "firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
)
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
} }
@@ -314,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
} }
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string var table string
if inbound { if inbound {
table = "firewall.inbound" table = "firewall.inbound"
@@ -372,7 +374,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
startPort = firewall.PortAny startPort = firewall.PortAny
endPort = firewall.PortAny endPort = firewall.PortAny
if sPort != "" { if sPort != "" {
l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule") l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
} }
default: default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
@@ -396,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
} }
if warning := r.sanity(); warning != nil { if warning := r.sanity(); warning != nil {
l.Warnf("%s rule #%v; %s", table, i, warning) l.Warn("firewall rule sanity check",
"table", table,
"rule", i,
"warning", warning,
)
} }
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
@@ -528,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// We now know which firewall table to check against // We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l). h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
WithField("fwPacket", fp). "fwPacket", fp,
WithField("incoming", c.incoming). "incoming", c.incoming,
WithField("rulesVersion", f.rulesVersion). "rulesVersion", f.rulesVersion,
WithField("oldRulesVersion", c.rulesVersion). "oldRulesVersion", c.rulesVersion,
Debugln("dropping old conntrack entry, does not match new ruleset") )
} }
delete(conntrack.Conns, fp) delete(conntrack.Conns, fp)
conntrack.Unlock() conntrack.Unlock()
return false return false
} }
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
h.logger(f.l). h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
WithField("fwPacket", fp). "fwPacket", fp,
WithField("incoming", c.incoming). "incoming", c.incoming,
WithField("rulesVersion", f.rulesVersion). "rulesVersion", f.rulesVersion,
WithField("oldRulesVersion", c.rulesVersion). "oldRulesVersion", c.rulesVersion,
Debugln("keeping old conntrack entry, does match new ruleset") )
} }
c.rulesVersion = f.rulesVersion c.rulesVersion = f.rulesVersion
@@ -935,7 +941,7 @@ type rule struct {
CASha string CASha string
} }
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
r := rule{} r := rule{}
m, ok := p.(map[string]any) m, ok := p.(map[string]any)
@@ -966,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
return r, errors.New("group should contain a single value, an array with more than one entry was provided") return r, errors.New("group should contain a single value, an array with more than one entry was provided")
} }
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) l.Warn("group was an array with a single value, converting to simple value",
"table", table,
"rule", i,
)
m["group"] = v[0] m["group"] = v[0]
} }

View File

@@ -2,10 +2,9 @@ package firewall
import ( import (
"context" "context"
"log/slog"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/sirupsen/logrus"
) )
// ConntrackCache is used as a local routine cache to know if a given flow // ConntrackCache is used as a local routine cache to know if a given flow
@@ -16,15 +15,17 @@ type ConntrackCacheTicker struct {
cacheV uint64 cacheV uint64
cacheTick atomic.Uint64 cacheTick atomic.Uint64
l *slog.Logger
cache ConntrackCache cache ConntrackCache
} }
func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker { func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
if d == 0 { if d == 0 {
return nil return nil
} }
c := &ConntrackCacheTicker{ c := &ConntrackCacheTicker{
l: l,
cache: ConntrackCache{}, cache: ConntrackCache{},
} }
@@ -48,15 +49,15 @@ func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
// Get checks if the cache ticker has moved to the next version before returning // Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map. // the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { func (c *ConntrackCacheTicker) Get() ConntrackCache {
if c == nil { if c == nil {
return nil return nil
} }
if tick := c.cacheTick.Load(); tick != c.cacheV { if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick c.cacheV = tick
if ll := len(c.cache); ll > 0 { if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel { if c.l.Enabled(context.Background(), slog.LevelDebug) {
l.WithField("len", ll).Debug("resetting conntrack cache") c.l.Debug("resetting conntrack cache", "len", ll)
} }
c.cache = make(ConntrackCache, ll) c.cache = make(ConntrackCache, ll)
} }

69
firewall/cache_test.go Normal file
View File

@@ -0,0 +1,69 @@
package firewall
import (
"bytes"
"log/slog"
"strings"
"testing"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
// The tests below pin the log format produced by ConntrackCacheTicker.Get
// so changes cannot silently break what operators are grepping for. The
// ticker's internal state (cache + cacheTick) is poked directly to avoid
// racing a goroutine-driven tick in tests.
// newFixedTicker returns a ConntrackCacheTicker primed so that the next
// Get() call observes a tick advance and takes the cache-reset path. State
// is set directly (cache contents + cacheTick) so tests never race against
// the background tick goroutine.
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
	t.Helper()
	ticker := &ConntrackCacheTicker{
		l:     l,
		cache: make(ConntrackCache, cacheLen),
	}
	// Populate distinct flows so len(cache) == cacheLen when Get() runs.
	for i := 0; i < cacheLen; i++ {
		key := Packet{LocalPort: uint16(i) + 1}
		ticker.cache[key] = struct{}{}
	}
	// cacheV starts at 0; storing 1 guarantees Get() sees a stale version.
	ticker.cacheTick.Store(1)
	return ticker
}
// TestConntrackCacheTicker_Get_TextFormat pins the exact text-handler line
// emitted when the conntrack cache is reset at debug level.
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
	var out bytes.Buffer
	logger := test.NewLoggerWithOutputAndLevel(&out, slog.LevelDebug)

	newFixedTicker(t, logger, 3).Get()

	assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", out.String())
}
// TestConntrackCacheTicker_Get_JSONFormat pins the JSON-handler payload
// emitted when the conntrack cache is reset at debug level.
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
	var out bytes.Buffer
	logger := test.NewJSONLoggerWithOutput(&out, slog.LevelDebug)

	newFixedTicker(t, logger, 2).Get()

	got := strings.TrimSpace(out.String())
	assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, got)
}
// TestConntrackCacheTicker_Get_QuietBelowDebug verifies that a reset with a
// non-empty cache logs nothing when the logger level is above debug.
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
	var out bytes.Buffer
	logger := test.NewLoggerWithOutputAndLevel(&out, slog.LevelInfo)

	newFixedTicker(t, logger, 5).Get()

	assert.Empty(t, out.String())
}
// TestConntrackCacheTicker_Get_QuietWhenCacheEmpty verifies that resetting an
// empty cache logs nothing even at debug level (the len > 0 guard).
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
	var out bytes.Buffer
	logger := test.NewLoggerWithOutputAndLevel(&out, slog.LevelDebug)

	newFixedTicker(t, logger, 0).Get()

	assert.Empty(t, out.String())
}

View File

@@ -3,13 +3,13 @@ package nebula
import ( import (
"bytes" "bytes"
"errors" "errors"
"log/slog"
"math" "math"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
} }
func TestFirewall_AddRule(t *testing.T) { func TestFirewall_AddRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
c := &dummyCert{} c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@@ -177,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
@@ -254,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
} }
func TestFirewall_DropV6(t *testing.T) { func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -485,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
func TestFirewall_Drop2(t *testing.T) { func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -544,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -633,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
} }
func TestFirewall_Drop3V6(t *testing.T) { func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
@@ -671,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
} }
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -736,9 +729,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
} }
func TestFirewall_ICMPPortBehavior(t *testing.T) { func TestFirewall_ICMPPortBehavior(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
@@ -880,9 +872,8 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
} }
func TestFirewall_DropIPSpoofing(t *testing.T) { func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
@@ -1045,25 +1036,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err) require.NoError(t, err)
conf := config.NewC(l) conf := config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code // Test both port and code
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha // Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
@@ -1073,25 +1064,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error // Test proto error
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error // Test cidr parse error
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error // Test local_cidr parse error
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups // Test both group and groups
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
@@ -1100,35 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
func TestAddFirewallRulesFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Test adding tcp rule // Test adding tcp rule
conf := config.NewC(l) conf := config.NewC(test.NewLogger())
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule no port // Test adding icmp rule no port
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -1136,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr // Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8") cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
@@ -1151,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding rule with cidr ipv6 // Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8") cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr // Test adding rule with any cidr
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr // Test adding rule with junk cidr
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6 // Test adding rule with local_cidr ipv6
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr // Test adding rule with any local_cidr
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr // Test adding rule with junk local_cidr
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(test.NewLogger())
mf = &mockFirewall{} mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error") mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
@@ -1234,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
// Ensure group array of 1 is converted and a warning is printed // Ensure group array of 1 is converted and a warning is printed
c := map[string]any{ c := map[string]any{
@@ -1244,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
} }
r, err := convertRule(l, c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
assert.Contains(t, ob.String(), "table=test")
assert.Contains(t, ob.String(), "rule=1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, []string{"group1"}, r.Groups)
@@ -1270,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
} }
func TestFirewall_convertRuleSanity(t *testing.T) { func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
noWarningPlease := []map[string]any{ noWarningPlease := []map[string]any{
{"group": "group1"}, {"group": "group1"},
@@ -1386,7 +1377,7 @@ type testsetup struct {
fw *Firewall fw *Firewall
} }
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{ c := dummyCert{
name: "me", name: "me",
networks: myPrefixes, networks: myPrefixes,
@@ -1397,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
return newSetupFromCert(t, l, c) return newSetupFromCert(t, l, c)
} }
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() { for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix) myVpnNetworksTable.Insert(prefix)
@@ -1414,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel() t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l := test.NewLoggerWithOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8") myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out

1
go.mod
View File

@@ -18,7 +18,6 @@ require (
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.4
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1

2
go.sum
View File

@@ -133,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=

View File

@@ -2,11 +2,12 @@ package nebula
import ( import (
"bytes" "bytes"
"context"
"log/slog"
"net/netip" "net/netip"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@@ -18,8 +19,11 @@ import (
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh) err := f.handshakeManager.allocateIndex(hh)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). f.l.Error("Failed to generate index",
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") "error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false return false
} }
@@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
crt := cs.getCertificate(v) crt := cs.getCertificate(v)
if crt == nil { if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). f.l.Error("Unable to handshake with host because no certificate is available",
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). "vpnAddrs", hh.hostinfo.vpnAddrs,
WithField("certVersion", v). "handshake", m{"stage": 0, "style": "ix_psk0"},
Error("Unable to handshake with host because no certificate is available") "certVersion", v,
)
return false return false
} }
crtHs := cs.getHandshakeBytes(v) crtHs := cs.getHandshakeBytes(v)
if crtHs == nil { if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). "vpnAddrs", hh.hostinfo.vpnAddrs,
WithField("certVersion", v). "handshake", m{"stage": 0, "style": "ix_psk0"},
Error("Unable to handshake with host because no certificate handshake bytes is available") "certVersion", v,
)
return false return false
} }
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). f.l.Error("Failed to create connection state",
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). "error", err,
WithField("certVersion", v). "vpnAddrs", hh.hostinfo.vpnAddrs,
Error("Failed to create connection state") "handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false return false
} }
hh.hostinfo.ConnectionState = ci hh.hostinfo.ConnectionState = ci
@@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). f.l.Error("Failed to marshal handshake message",
WithField("certVersion", v). "error", err,
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") "vpnAddrs", hh.hostinfo.vpnAddrs,
"certVersion", v,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false return false
} }
@@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes) msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). f.l.Error("Failed to call noise.WriteMessage",
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") "error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false return false
} }
@@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
cs := f.pki.getCertState() cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate() crt := cs.GetDefaultCertificate()
if crt == nil { if crt == nil {
f.l.WithField("from", via). f.l.Error("Unable to handshake with host because no certificate is available",
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). "from", via,
WithField("certVersion", cs.initiatingVersion). "handshake", m{"stage": 0, "style": "ix_psk0"},
Error("Unable to handshake with host because no certificate is available") "certVersion", cs.initiatingVersion,
)
return return
} }
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.Error("Failed to create connection state",
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "error", err,
Error("Failed to create connection state") "from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
@@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.Error("Failed to call noise.ReadMessage",
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "error", err,
Error("Failed to call noise.ReadMessage") "from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("from", via). f.l.Error("Failed unmarshal handshake message",
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "error", err,
Error("Failed unmarshal handshake message") "from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.Info("Handshake did not contain a certificate",
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "error", err,
Info("Handshake did not contain a certificate") "from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
@@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("from", via). attrs := []slog.Attr{
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). slog.Any("error", err),
WithField("certVpnNetworks", rc.Networks()). slog.Any("from", via),
WithField("certFingerprint", fp) slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
slog.Any("certVpnNetworks", rc.Networks()),
if f.l.Level >= logrus.DebugLevel { slog.String("certFingerprint", fp),
e = e.WithField("cert", rc) }
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
} }
e.Info("Invalid certificate from host") // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return return
} }
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via). f.l.Info("public key mismatch between certificate and handshake",
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "from", via,
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") "handshake", m{"stage": 1, "style": "ix_psk0"},
"cert", remoteCert,
)
return return
} }
@@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us // We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil { if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithError(err).WithFields(m{ f.l.Debug("Might be unable to handshake with host due to missing certificate version",
"from": via, "error", err,
"handshake": m{"stage": 1, "style": "ix_psk0"}, "from", via,
"cert": remoteCert, "handshake", m{"stage": 1, "style": "ix_psk0"},
}).Debug("Might be unable to handshake with host due to missing certificate version") "cert", remoteCert,
)
} }
} else { } else {
// Record the certificate we are actually using // Record the certificate we are actually using
@@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via). f.l.Info("No networks in certificate",
WithField("cert", remoteCert). "error", err,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "from", via,
Info("No networks in certificate") "cert", remoteCert,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
@@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks { for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) { if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). f.l.Error("Refusing to handshake with myself",
WithField("certName", certName). "vpnNetworks", vpnNetworks,
WithField("certVersion", certVersion). "from", via,
WithField("fingerprint", fingerprint). "certName", certName,
WithField("issuer", issuer). "certVersion", certVersion,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") "fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
vpnAddrs[i] = network.Addr() vpnAddrs[i] = network.Addr()
@@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
if !via.IsRelayed { if !via.IsRelayed {
// We only want to apply the remote allow list for direct tunnels here // We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). if f.l.Enabled(context.Background(), slog.LevelDebug) {
Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", vpnAddrs,
"from", via,
)
}
return return
} }
} }
myIndex, err := generateIndex(f.l) myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.Error("Failed to generate index",
WithField("certName", certName). "error", err,
WithField("certVersion", certVersion). "vpnAddrs", vpnAddrs,
WithField("fingerprint", fingerprint). "from", via,
WithField("issuer", issuer). "certName", certName,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") "certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
@@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
}, },
} }
msgRxL := f.l.WithFields(m{ msgRxL := f.l.With(
"vpnAddrs": vpnAddrs, "vpnAddrs", vpnAddrs,
"from": via, "from", via,
"certName": certName, "certName", certName,
"certVersion": certVersion, "certVersion", certVersion,
"fingerprint": fingerprint, "fingerprint", fingerprint,
"issuer": issuer, "issuer", issuer,
"initiatorIndex": hs.Details.InitiatorIndex, "initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex": hs.Details.ResponderIndex, "responderIndex", hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex, "remoteIndex", h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"}, "handshake", m{"stage": 1, "style": "ix_psk0"},
}) )
if anyVpnAddrsInCommon { if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received") msgRxL.Info("Handshake message received")
@@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil { if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()). msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
Error("Unable to handshake with host because no certificate handshake bytes is available") "myCertVersion", ci.myCert.Version(),
)
return return
} }
@@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.Error("Failed to marshal handshake message",
WithField("certName", certName). "error", err,
WithField("certVersion", certVersion). "vpnAddrs", hostinfo.vpnAddrs,
WithField("fingerprint", fingerprint). "from", via,
WithField("issuer", issuer). "certName", certName,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") "certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.Error("Failed to call noise.WriteMessage",
WithField("certName", certName). "error", err,
WithField("certVersion", certVersion). "vpnAddrs", hostinfo.vpnAddrs,
WithField("fingerprint", fingerprint). "from", via,
WithField("issuer", issuer). "certName", certName,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") "certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.Error("Noise did not arrive at a key",
WithField("certName", certName). "vpnAddrs", hostinfo.vpnAddrs,
WithField("certVersion", certVersion). "from", via,
WithField("fingerprint", fingerprint). "certName", certName,
WithField("issuer", issuer). "certVersion", certVersion,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") "fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
@@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
if !via.IsRelayed { if !via.IsRelayed {
err := f.outside.WriteTo(msg, via.UdpAddr) err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil { if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). f.l.Error("Failed to send handshake message",
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). "vpnAddrs", existing.vpnAddrs,
WithError(err).Error("Failed to send handshake message") "from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
"error", err,
)
} else { } else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). f.l.Info("Handshake message sent",
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). "vpnAddrs", existing.vpnAddrs,
Info("Handshake message sent") "from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
} }
return return
} else { } else {
@@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). f.l.Info("Handshake message sent",
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). "vpnAddrs", existing.vpnAddrs,
Info("Handshake message sent") "relay", via.relayHI.vpnAddrs[0],
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
return return
} }
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.Info("Handshake too old",
WithField("certName", certName). "vpnAddrs", vpnAddrs,
WithField("certVersion", certVersion). "from", via,
WithField("oldHandshakeTime", existing.lastHandshakeTime). "certName", certName,
WithField("newHandshakeTime", hostinfo.lastHandshakeTime). "certVersion", certVersion,
WithField("fingerprint", fingerprint). "oldHandshakeTime", existing.lastHandshakeTime,
WithField("issuer", issuer). "newHandshakeTime", hostinfo.lastHandshakeTime,
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). "fingerprint", fingerprint,
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "issuer", issuer,
Info("Handshake too old") "initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.Error("Failed to add HostInfo due to localIndex collision",
WithField("certName", certName). "vpnAddrs", vpnAddrs,
WithField("certVersion", certVersion). "from", via,
WithField("fingerprint", fingerprint). "certName", certName,
WithField("issuer", issuer). "certVersion", certVersion,
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). "fingerprint", fingerprint,
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "issuer", issuer,
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). "initiatorIndex", hs.Details.InitiatorIndex,
Error("Failed to add HostInfo due to localIndex collision") "responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"localIndex", hostinfo.localIndexId,
"collision", existing.vpnAddrs,
)
return return
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.Error("Failed to add HostInfo to HostMap",
WithField("certName", certName). "error", err,
WithField("certVersion", certVersion). "vpnAddrs", vpnAddrs,
WithField("fingerprint", fingerprint). "from", via,
WithField("issuer", issuer). "certName", certName,
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). "certVersion", certVersion,
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "fingerprint", fingerprint,
Error("Failed to add HostInfo to HostMap") "issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return return
} }
} }
@@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed { if !via.IsRelayed {
err = f.outside.WriteTo(msg, via.UdpAddr) err = f.outside.WriteTo(msg, via.UdpAddr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). log := f.l.With(
WithField("certName", certName). "vpnAddrs", vpnAddrs,
WithField("certVersion", certVersion). "from", via,
WithField("fingerprint", fingerprint). "certName", certName,
WithField("issuer", issuer). "certVersion", certVersion,
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). "fingerprint", fingerprint,
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) "issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
if err != nil { if err != nil {
log.WithError(err).Error("Failed to send handshake") log.Error("Failed to send handshake", "error", err)
} else { } else {
log.Info("Handshake message sent") log.Info("Handshake message sent")
} }
@@ -448,14 +533,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
// it's correctly marked as working. // it's correctly marked as working.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). f.l.Info("Handshake message sent",
WithField("certName", certName). "vpnAddrs", vpnAddrs,
WithField("certVersion", certVersion). "relay", via.relayHI.vpnAddrs[0],
WithField("fingerprint", fingerprint). "certName", certName,
WithField("issuer", issuer). "certVersion", certVersion,
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). "fingerprint", fingerprint,
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). "issuer", issuer,
Info("Handshake message sent") "initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
} }
f.connectionManager.AddTrafficWatch(hostinfo) f.connectionManager.AddTrafficWatch(hostinfo)
@@ -483,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
if !via.IsRelayed { if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
)
}
return false return false
} }
} }
@@ -491,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.Error("Failed to call noise.ReadMessage",
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). "error", err,
Error("Failed to call noise.ReadMessage") "vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"header", h,
)
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.Error("Noise did not arrive at a key",
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). "vpnAddrs", hostinfo.vpnAddrs,
Error("Noise did not arrive at a key") "from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// This should be impossible in IX but just in case, if we get here then there is no chance to recover // This should be impossible in IX but just in case, if we get here then there is no chance to recover
// the handshake state machine. Tear it down // the handshake state machine. Tear it down
@@ -512,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.Error("Failed unmarshal handshake message",
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") "error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true return true
@@ -521,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.Info("Handshake did not contain a certificate",
WithField("vpnAddrs", hostinfo.vpnAddrs). "error", err,
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). "from", via,
Info("Handshake did not contain a certificate") "vpnAddrs", hostinfo.vpnAddrs,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true return true
} }
@@ -535,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("from", via). attrs := []slog.Attr{
WithField("vpnAddrs", hostinfo.vpnAddrs). slog.Any("error", err),
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). slog.Any("from", via),
WithField("certFingerprint", fp). slog.Any("vpnAddrs", hostinfo.vpnAddrs),
WithField("certVpnNetworks", rc.Networks()) slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
slog.String("certFingerprint", fp),
if f.l.Level >= logrus.DebugLevel { slog.Any("certVpnNetworks", rc.Networks()),
e = e.WithField("cert", rc) }
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
} }
e.Info("Invalid certificate from host") // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return true return true
} }
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via). f.l.Info("public key mismatch between certificate and handshake",
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). "from", via,
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake") "handshake", m{"stage": 2, "style": "ix_psk0"},
"cert", remoteCert,
)
return true return true
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via). f.l.Info("No networks in certificate",
WithField("vpnAddrs", hostinfo.vpnAddrs). "error", err,
WithField("cert", remoteCert). "from", via,
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). "vpnAddrs", hostinfo.vpnAddrs,
Info("No networks in certificate") "cert", remoteCert,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true return true
} }
@@ -601,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.Info("Incorrect host responded to handshake",
WithField("from", via). "intendedVpnAddrs", hostinfo.vpnAddrs,
WithField("certName", certName). "haveVpnNetworks", vpnNetworks,
WithField("certVersion", certVersion). "from", via,
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). "certName", certName,
Info("Incorrect host responded to handshake") "certVersion", certVersion,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// Release our old handshake from pending, it should not continue // Release our old handshake from pending, it should not continue
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
@@ -618,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via) newHH.hostinfo.remotes.BlockRemote(via)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). f.l.Info("Blocked addresses for handshakes",
WithField("vpnNetworks", vpnNetworks). "blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). "vpnNetworks", vpnNetworks,
Info("Blocked addresses for handshakes") "remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
)
// Swap the packet store to benefit the original intended recipient // Swap the packet store to benefit the original intended recipient
newHH.packetStore = hh.packetStore newHH.packetStore = hh.packetStore
@@ -639,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). msgRxL := f.l.With(
WithField("certName", certName). "vpnAddrs", vpnAddrs,
WithField("certVersion", certVersion). "from", via,
WithField("fingerprint", fingerprint). "certName", certName,
WithField("issuer", issuer). "certVersion", certVersion,
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). "fingerprint", fingerprint,
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). "issuer", issuer,
WithField("durationNs", duration). "initiatorIndex", hs.Details.InitiatorIndex,
WithField("sentCachedPackets", len(hh.packetStore)) "responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"durationNs", duration,
"sentCachedPackets", len(hh.packetStore),
)
if anyVpnAddrsInCommon { if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received") msgRxL.Info("Handshake message received")
} else { } else {
@@ -663,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo) f.connectionManager.AddTrafficWatch(hostinfo)
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) hostinfo.logger(f.l).Debug("Sending stored packets",
"count", len(hh.packetStore),
)
} }
if len(hh.packetStore) > 0 { if len(hh.packetStore) > 0 {

View File

@@ -6,13 +6,13 @@ import (
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"log/slog"
"net/netip" "net/netip"
"slices" "slices"
"sync" "sync"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@@ -59,7 +59,7 @@ type HandshakeManager struct {
metricInitiated metrics.Counter metricInitiated metrics.Counter
metricTimedOut metrics.Counter metricTimedOut metrics.Counter
f *Interface f *Interface
l *logrus.Logger l *slog.Logger
// can be used to trigger outbound handshake for the given vpnIp // can be used to trigger outbound handshake for the given vpnIp
trigger chan netip.Addr trigger chan netip.Addr
@@ -78,32 +78,32 @@ type HandshakeHostInfo struct {
hostinfo *HostInfo hostinfo *HostInfo
} }
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
if len(hh.packetStore) < 100 { if len(hh.packetStore) < 100 {
tempPacket := make([]byte, len(packet)) tempPacket := make([]byte, len(packet))
copy(tempPacket, packet) copy(tempPacket, packet)
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
if l.Level >= logrus.DebugLevel { if l.Enabled(context.Background(), slog.LevelDebug) {
hh.hostinfo.logger(l). hh.hostinfo.logger(l).Debug("Packet store",
WithField("length", len(hh.packetStore)). "length", len(hh.packetStore),
WithField("stored", true). "stored", true,
Debugf("Packet store") )
} }
} else { } else {
m.dropped.Inc(1) m.dropped.Inc(1)
if l.Level >= logrus.DebugLevel { if l.Enabled(context.Background(), slog.LevelDebug) {
hh.hostinfo.logger(l). hh.hostinfo.logger(l).Debug("Packet store",
WithField("length", len(hh.packetStore)). "length", len(hh.packetStore),
WithField("stored", false). "stored", false,
Debugf("Packet store") )
} }
} }
} }
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{ return &HandshakeManager{
vpnIps: map[netip.Addr]*HandshakeHostInfo{}, vpnIps: map[netip.Addr]*HandshakeHostInfo{},
indexes: map[uint32]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{},
@@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if !via.IsRelayed { if !via.IsRelayed {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via)
return return
} }
} }
@@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
// If we are out of time, clean up // If we are out of time, clean up
if hh.counter >= hm.config.retries { if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). hh.hostinfo.logger(hm.l).Info("Handshake timed out",
WithField("initiatorIndex", hh.hostinfo.localIndexId). "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
WithField("remoteIndex", hh.hostinfo.remoteIndexId). "initiatorIndex", hh.hostinfo.localIndexId,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "remoteIndex", hh.hostinfo.remoteIndexId,
WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). "handshake", m{"stage": 1, "style": "ix_psk0"},
Info("Handshake timed out") "durationNs", time.Since(hh.startTime).Nanoseconds(),
)
hm.metricTimedOut.Inc(1) hm.metricTimedOut.Inc(1)
hm.DeleteHostInfo(hostinfo) hm.DeleteHostInfo(hostinfo)
return return
@@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil { if err != nil {
hostinfo.logger(hm.l).WithField("udpAddr", addr). hostinfo.logger(hm.l).Error("Failed to send handshake message",
WithField("initiatorIndex", hostinfo.localIndexId). "udpAddr", addr,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "initiatorIndex", hostinfo.localIndexId,
WithError(err).Error("Failed to send handshake message") "handshake", m{"stage": 1, "style": "ix_psk0"},
"error", err,
)
} else { } else {
sentTo = append(sentTo, addr) sentTo = append(sentTo, addr)
@@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
// so only log when the list of remotes has changed // so only log when the list of remotes has changed
if remotesHaveChanged { if remotesHaveChanged {
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). hostinfo.logger(hm.l).Info("Handshake message sent",
WithField("initiatorIndex", hostinfo.localIndexId). "udpAddrs", sentTo,
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "initiatorIndex", hostinfo.localIndexId,
Info("Handshake message sent") "handshake", m{"stage": 1, "style": "ix_psk0"},
} else if hm.l.Level >= logrus.DebugLevel { )
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). } else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
WithField("initiatorIndex", hostinfo.localIndexId). hostinfo.logger(hm.l).Debug("Handshake message sent",
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). "udpAddrs", sentTo,
Debug("Handshake message sent") "initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
} }
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to // Don't relay through the host I'm trying to connect to
@@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
hm.f.Handshake(relay) hm.f.Handshake(relay)
continue continue
} }
@@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
if relayHostInfo.remote.IsValid() { if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil { if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
} }
m := NebulaControl{ m := NebulaControl{
@@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
msg, err := m.Marshal() msg, err := m.Marshal()
if err != nil { if err != nil {
hostinfo.logger(hm.l). hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
WithError(err).
Error("Failed to marshal Control message to create relay")
} else { } else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{ hm.l.Info("send CreateRelayRequest",
"relayFrom": hm.f.myVpnAddrs[0], "relayFrom", hm.f.myVpnAddrs[0],
"relayTo": vpnIp, "relayTo", vpnIp,
"initiatorRelayIndex": idx, "initiatorRelayIndex", idx,
"relay": relay}). "relay", relay,
Info("send CreateRelayRequest") )
} }
} }
continue continue
@@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
switch existingRelay.State { switch existingRelay.State {
case Established: case Established:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished: case Disestablished:
// Mark this relay as 'requested' // Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough fallthrough
case Requested: case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
// Re-send the CreateRelay request, in case the previous one was lost. // Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{ m := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
@@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
} }
msg, err := m.Marshal() msg, err := m.Marshal()
if err != nil { if err != nil {
hostinfo.logger(hm.l). hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
WithError(err).
Error("Failed to marshal Control message to create relay")
} else { } else {
// This must send over the hostinfo, not over hm.Hosts[ip] // This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{ hm.l.Info("send CreateRelayRequest",
"relayFrom": hm.f.myVpnAddrs[0], "relayFrom", hm.f.myVpnAddrs[0],
"relayTo": vpnIp, "relayTo", vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex, "initiatorRelayIndex", existingRelay.LocalIndex,
"relay": relay}). "relay", relay,
Info("send CreateRelayRequest") )
} }
case PeerRequested: case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough fallthrough
default: default:
hostinfo.logger(hm.l). hostinfo.logger(hm.l).Error("Relay unexpected state",
WithField("vpnIp", vpnIp). "vpnIp", vpnIp,
WithField("state", existingRelay.State). "state", existingRelay.State,
WithField("relay", relay). "relay", relay,
Errorf("Relay unexpected state") )
} }
} }
@@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l). hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). "remoteIndex", hostinfo.remoteIndexId,
Info("New host shadows existing host remoteIndex") "collision", existingRemoteIndex.vpnAddrs,
)
} }
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
@@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
if found && existingRemoteIndex != nil { if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l). hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). "remoteIndex", hostinfo.remoteIndexId,
Info("New host shadows existing host remoteIndex") "collision", existingRemoteIndex.vpnAddrs,
)
} }
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
@@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
hm.indexes = map[uint32]*HandshakeHostInfo{} hm.indexes = map[uint32]*HandshakeHostInfo{}
} }
if hm.l.Level >= logrus.DebugLevel { if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), hm.l.Debug("Pending hostmap hostInfo deleted",
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). "hostMap", m{"mapTotalSize": len(hm.vpnIps),
Debug("Pending hostmap hostInfo deleted") "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
)
} }
} }
@@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() {
// Utility functions below // Utility functions below
func generateIndex(l *logrus.Logger) (uint32, error) { func generateIndex(l *slog.Logger) (uint32, error) {
b := make([]byte, 4) b := make([]byte, 4)
// Let zero mean we don't know the ID, so don't generate zero // Let zero mean we don't know the ID, so don't generate zero
@@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
for index == 0 { for index == 0 {
_, err := rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
l.Errorln(err) l.Error("Failed to generate index", "error", err)
return 0, err return 0, err
} }
index = binary.BigEndian.Uint32(b) index = binary.BigEndian.Uint32(b)
} }
if l.Level >= logrus.DebugLevel { if l.Enabled(context.Background(), slog.LevelDebug) {
l.WithField("index", index). l.Debug("Generated index", "index", index)
Debug("Generated index")
} }
return index, nil return index, nil
} }

View File

@@ -1,9 +1,11 @@
package nebula package nebula
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@@ -13,10 +15,10 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
) )
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
@@ -60,7 +62,7 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix] preferredRanges atomic.Pointer[[]netip.Prefix]
l *logrus.Logger l *slog.Logger
} }
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay // For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
@@ -313,7 +315,7 @@ type cachedPacketMetrics struct {
dropped metrics.Counter dropped metrics.Counter
} }
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
hm := newHostMap(l) hm := newHostMap(l)
hm.reload(c, true) hm.reload(c, true)
@@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm.reload(c, false) hm.reload(c, false)
}) })
l.WithField("preferredRanges", hm.GetPreferredRanges()). l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
Info("Main HostMap created")
return hm return hm
} }
func newHostMap(l *logrus.Logger) *HostMap { func newHostMap(l *slog.Logger) *HostMap {
return &HostMap{ return &HostMap{
Indexes: map[uint32]*HostInfo{}, Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{},
@@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
preferredRange, err := netip.ParsePrefix(rawPreferredRange) preferredRange, err := netip.ParsePrefix(rawPreferredRange)
if err != nil { if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") hm.l.Warn("Failed to parse preferred ranges, ignoring",
"error", err,
"range", rawPreferredRanges,
)
continue continue
} }
@@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
oldRanges := hm.preferredRanges.Swap(&preferredRanges) oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial { if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") hm.l.Info("preferred_ranges changed",
"oldPreferredRanges", *oldRanges,
"newPreferredRanges", preferredRanges,
)
} }
} }
} }
@@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
hm.Indexes = map[uint32]*HostInfo{} hm.Indexes = map[uint32]*HostInfo{}
} }
if hm.l.Level >= logrus.DebugLevel { if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), hm.l.Debug("Hostmap hostInfo deleted",
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). "hostMap", m{"mapTotalSize": len(hm.Hosts),
Debug("Hostmap hostInfo deleted") "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
)
} }
if isLastHostinfo { if isLastHostinfo {
@@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel { if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), hm.l.Debug("Hostmap vpnIp added",
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). "hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
Debug("Hostmap vpnIp added") "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
)
} }
} }
@@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
} }
} }
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
if i == nil { if i == nil {
return logrus.NewEntry(l) return l
} }
li := l.WithField("vpnAddrs", i.vpnAddrs). li := l.With(
WithField("localIndex", i.localIndexId). "vpnAddrs", i.vpnAddrs,
WithField("remoteIndex", i.remoteIndexId) "localIndex", i.localIndexId,
"remoteIndex", i.remoteIndexId,
)
if connState := i.ConnectionState; connState != nil { if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil { if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Certificate.Name()) li = li.With("certName", peerCert.Certificate.Name())
} }
} }
@@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions // Utility functions
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage //FIXME: This function is pretty garbage
var finalAddrs []netip.Addr var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces() ifaces, _ := net.Interfaces()
for _, i := range ifaces { for _, i := range ifaces {
allow := allowList.AllowName(i.Name) allow := allowList.AllowName(i.Name)
if l.Level >= logrus.TraceLevel { if l.Enabled(context.Background(), logging.LevelTrace) {
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName") l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
"interfaceName", i.Name,
"allow", allow,
)
} }
if !allow { if !allow {
@@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
} }
if !addr.IsValid() { if !addr.IsValid() {
if l.Level >= logrus.DebugLevel { if l.Enabled(context.Background(), slog.LevelDebug) {
l.WithField("localAddr", rawAddr).Debug("addr was invalid") l.Debug("addr was invalid", "localAddr", rawAddr)
} }
continue continue
} }
@@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
isAllowed := allowList.Allow(addr) isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel { if l.Enabled(context.Background(), logging.LevelTrace) {
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
"localAddr", addr,
"allowed", isAllowed,
)
} }
if !isAllowed { if !isAllowed {
continue continue

View File

@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
func TestHostMap_reload(t *testing.T) { func TestHostMap_reload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(test.NewLogger())
hm := NewHostMapFromConfig(l, c) hm := NewHostMapFromConfig(l, c)

119
inside.go
View File

@@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"context"
"log/slog"
"net/netip" "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
@@ -14,8 +15,11 @@ import (
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) f.l.Debug("Error while validating outbound packet",
"packet", packet,
"error", err,
)
} }
return return
} }
@@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if immediatelyForwardToSelf { if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet) _, err := f.readers[q].Write(packet)
if err != nil { if err != nil {
f.l.WithError(err).Error("Failed to forward to tun") f.l.Error("Failed to forward to tun", "error", err)
} }
} }
// Otherwise, drop. On linux, we should never see these packets - Linux // Otherwise, drop. On linux, we should never see these packets - Linux
@@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
if hostinfo == nil { if hostinfo == nil {
f.rejectInside(packet, out, q) f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr). f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
WithField("fwPacket", fwPacket). "vpnAddr", fwPacket.RemoteAddr,
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") "fwPacket", fwPacket,
)
} }
return return
} }
@@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} else { } else {
f.rejectInside(packet, out, q) f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l). hostinfo.logger(f.l).Debug("dropping outbound packet",
WithField("fwPacket", fwPacket). "fwPacket", fwPacket,
WithField("reason", dropReason). "reason", dropReason,
Debugln("dropping outbound packet") )
} }
} }
} }
@@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
_, err := f.readers[q].Write(out) _, err := f.readers[q].Write(out)
if err != nil { if err != nil {
f.l.WithError(err).Error("Failed to write to tun") f.l.Error("Failed to write to tun", "error", err)
} }
} }
@@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
} }
if len(out) > iputil.MaxRejectPacketSize { if len(out) > iputil.MaxRejectPacketSize {
if f.l.GetLevel() >= logrus.InfoLevel { if f.l.Enabled(context.Background(), slog.LevelInfo) {
f.l. f.l.Info("rejectOutside: packet too big, not sending",
WithField("packet", packet). "packet", packet,
WithField("outPacket", out). "outPacket", out,
Info("rejectOutside: packet too big, not sending") )
} }
return return
} }
@@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
// This would also need to interact with unsafe_route updates through reloading the config or // This would also need to interact with unsafe_route updates through reloading the config or
// use of the use_system_route_table option // use of the use_system_route_table option
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("destination", destinationAddr). f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
WithField("originalGateway", gatewayAddr). "destination", destinationAddr,
Debugln("Calculated gateway for ECMP not available, attempting other gateways") "originalGateway", gatewayAddr,
)
} }
for i := range gateways { for i := range gateways {
@@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
fp := &firewall.Packet{} fp := &firewall.Packet{}
err := newPacket(p, false, fp) err := newPacket(p, false, fp)
if err != nil { if err != nil {
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
return return
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil { if dropReason != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("fwPacket", fp). f.l.Debug("dropping cached packet",
WithField("reason", dropReason). "fwPacket", fp,
Debugln("dropping cached packet") "reason", dropReason,
)
} }
return return
} }
@@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
}) })
if hostInfo == nil { if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("vpnAddr", vpnAddr). f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") "vpnAddr", vpnAddr,
)
} }
return return
} }
@@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo,
if noiseutil.EncryptLockNeeded { if noiseutil.EncryptLockNeeded {
via.ConnectionState.writeLock.Unlock() via.ConnectionState.writeLock.Unlock()
} }
via.logger(f.l). via.logger(f.l).Error("SendVia out buffer not large enough for relay",
WithField("outCap", cap(out)). "outCap", cap(out),
WithField("payloadLen", len(ad)). "payloadLen", len(ad),
WithField("headerLen", len(out)). "headerLen", len(out),
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()). "cipherOverhead", via.ConnectionState.eKey.Overhead(),
Error("SendVia out buffer not large enough for relay") )
return return
} }
@@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo,
via.ConnectionState.writeLock.Unlock() via.ConnectionState.writeLock.Unlock()
} }
if err != nil { if err != nil {
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
return return
} }
err = f.writers[0].WriteTo(out, via.remote) err = f.writers[0].WriteTo(out, via.remote)
if err != nil { if err != nil {
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
} }
f.connectionManager.RelayUsed(relay.LocalIndex) f.connectionManager.RelayUsed(relay.LocalIndex)
} }
@@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
"vpnAddrs", hostinfo.vpnAddrs,
)
} }
} }
@@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
ci.writeLock.Unlock() ci.writeLock.Unlock()
} }
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
WithField("udpAddr", remote).WithField("counter", c). "error", err,
WithField("attemptedCounter", c). "udpAddr", remote,
Error("Failed to encrypt outgoing packet") "counter", c,
"attemptedCounter", c,
)
return return
} }
if remote.IsValid() { if remote.IsValid() {
err = f.writers[q].WriteTo(out, remote) err = f.writers[q].WriteTo(out, remote)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).Error("Failed to write outgoing packet",
WithField("udpAddr", remote).Error("Failed to write outgoing packet") "error", err,
"udpAddr", remote,
)
} }
} else if hostinfo.remote.IsValid() { } else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote) err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).Error("Failed to write outgoing packet",
WithField("udpAddr", remote).Error("Failed to write outgoing packet") "error", err,
"udpAddr", remote,
)
} }
} else { } else {
// Try to send via a relay // Try to send via a relay
@@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil { if err != nil {
hostinfo.relayState.DeleteRelay(relayIP) hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
"relay", relayIP,
"error", err,
)
continue continue
} }
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -12,7 +13,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@@ -46,7 +47,7 @@ type InterfaceConfig struct {
reQueryWait time.Duration reQueryWait time.Duration
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
l *logrus.Logger l *slog.Logger
} }
type Interface struct { type Interface struct {
@@ -100,7 +101,7 @@ type Interface struct {
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger l *slog.Logger
} }
type EncWriter interface { type EncWriter interface {
@@ -223,13 +224,16 @@ func (f *Interface) activate() error {
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
if err != nil { if err != nil {
f.l.WithError(err).Error("Failed to get udp listen address") f.l.Error("Failed to get udp listen address", "error", err)
} }
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). f.l.Info("Nebula interface is active",
WithField("build", f.version).WithField("udpAddr", addr). "interface", f.inside.Name(),
WithField("boringcrypto", boringEnabled()). "networks", f.myVpnNetworks,
Info("Nebula interface is active") "build", f.version,
"udpAddr", addr,
"boringcrypto", boringEnabled(),
)
if f.routines > 1 { if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() { if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
@@ -305,7 +309,7 @@ func (f *Interface) listenOut(i int) {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU) plaintext := make([]byte, udp.MTU)
h := &header.H{} h := &header.H{}
@@ -313,15 +317,15 @@ func (f *Interface) listenOut(i int) {
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
}) })
if err != nil && !f.closed.Load() { if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading inbound packet, closing") f.l.Error("Error while reading inbound packet, closing", "error", err)
f.onFatal(err) f.onFatal(err)
} }
f.l.Debugf("underlay reader %v is done", i) f.l.Debug("underlay reader is done", "reader", i)
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -330,22 +334,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
for { for {
n, err := reader.Read(packet) n, err := reader.Read(packet)
if err != nil { if err != nil {
if !f.closed.Load() { if !f.closed.Load() {
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing") f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
f.onFatal(err) f.onFatal(err)
} }
break break
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
} }
f.l.Debugf("overlay reader %v is done", i) f.l.Debug("overlay reader is done", "reader", i)
} }
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -365,7 +369,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
if initial || c.HasChanged("pki.disconnect_invalid") { if initial || c.HasChanged("pki.disconnect_invalid") {
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
if !initial { if !initial {
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
} }
} }
} }
@@ -379,7 +383,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil { if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload") f.l.Error("Error while creating firewall during reload", "error", err)
return return
} }
@@ -392,10 +396,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be // If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case. // safe and just reset conntrack in this case.
if fw.rulesVersion == 0 { if fw.rulesVersion == 0 {
f.l.WithField("firewallHashes", fw.GetRuleHashes()). f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
WithField("oldFirewallHashes", oldFw.GetRuleHashes()). "firewallHashes", fw.GetRuleHashes(),
WithField("rulesVersion", fw.rulesVersion). "oldFirewallHashes", oldFw.GetRuleHashes(),
Warn("firewall rulesVersion has overflowed, resetting conntrack") "rulesVersion", fw.rulesVersion,
)
} else { } else {
fw.Conntrack = conntrack fw.Conntrack = conntrack
} }
@@ -403,10 +408,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
f.firewall = fw f.firewall = fw
oldFw.Destroy() oldFw.Destroy()
f.l.WithField("firewallHashes", fw.GetRuleHashes()). f.l.Info("New firewall has been installed",
WithField("oldFirewallHashes", oldFw.GetRuleHashes()). "firewallHashes", fw.GetRuleHashes(),
WithField("rulesVersion", fw.rulesVersion). "oldFirewallHashes", oldFw.GetRuleHashes(),
Info("New firewall has been installed") "rulesVersion", fw.rulesVersion,
)
} }
func (f *Interface) reloadSendRecvError(c *config.C) { func (f *Interface) reloadSendRecvError(c *config.C) {
@@ -428,8 +434,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
} }
} }
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()). f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
Info("Loaded send_recv_error config")
} }
} }
@@ -452,8 +457,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
} }
} }
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()). f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
Info("Loaded accept_recv_error config")
} }
} }
@@ -527,7 +531,7 @@ func (f *Interface) Close() error {
for i, u := range f.writers { for i, u := range f.writers {
err := u.Close() err := u.Close()
if err != nil { if err != nil {
f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket") f.l.Error("Error while closing udp socket", "error", err, "writer", i)
errs = append(errs, err) errs = append(errs, err)
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@@ -15,10 +16,10 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -76,12 +77,12 @@ type LightHouse struct {
metrics *MessageMetrics metrics *MessageMetrics
metricHolepunchTx metrics.Counter metricHolepunchTx metrics.Counter
l *logrus.Logger l *slog.Logger
} }
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload // addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0)) nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 { if amLighthouse && nebulaPort == 0 {
@@ -133,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
case *util.ContextualError: case *util.ContextualError:
v.Log(l) v.Log(l)
case error: case error:
l.WithError(err).Error("failed to reload lighthouse") l.Error("failed to reload lighthouse", "error", err)
} }
}) })
@@ -205,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
addr := addrs[0].Unmap() addr := addrs[0].Unmap()
if lh.myVpnNetworksTable.Contains(addr) { if lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1). lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") "addr", rawAddr,
"entry", i+1,
)
continue continue
} }
@@ -224,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10))) lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
if !initial { if !initial {
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load()) lh.l.Info("lighthouse.interval changed",
"interval", lh.interval.Load(),
)
if lh.updateCancel != nil { if lh.updateCancel != nil {
// May not always have a running routine // May not always have a running routine
@@ -336,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
for _, v := range c.GetStringSlice("relay.relays", nil) { for _, v := range c.GetStringSlice("relay.relays", nil) {
configRIP, err := netip.ParseAddr(v) configRIP, err := netip.ParseAddr(v)
if err != nil { if err != nil {
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") lh.l.Warn("Parse relay from config failed",
"relay", v,
"error", err,
)
} else { } else {
lh.l.WithField("relay", v).Info("Read relay from config") lh.l.Info("Read relay from config", "relay", v)
relaysForMe = append(relaysForMe, configRIP) relaysForMe = append(relaysForMe, configRIP)
} }
} }
@@ -363,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not") "vpnAddr", addr,
"networks", lh.myVpnNetworks,
)
} }
out[i] = addr out[i] = addr
} }
@@ -435,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
} }
if !lh.myVpnNetworksTable.Contains(vpnAddr) { if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work") "vpnAddr", vpnAddr,
"networks", lh.myVpnNetworks,
"entry", i+1,
)
} }
vals, ok := v.([]any) vals, ok := v.([]any)
@@ -537,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
lh.Lock() lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]] rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok { if ok {
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
for _, addr := range allVpnAddrs { for _, addr := range allVpnAddrs {
srm := lh.addrMap[addr] srm := lh.addrMap[addr]
if srm == rm { if srm == rm {
delete(lh.addrMap, addr) delete(lh.addrMap, addr)
if lh.l.Level >= logrus.DebugLevel { if debugEnabled {
lh.l.Debugf("deleting %s from lighthouse.", addr) lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
} }
} }
} }
@@ -659,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to) allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow). lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
Trace("remoteAllowList.Allow") "vpnAddrs", vpnAddrs,
"udpAddr", to,
"allow", allow,
)
} }
if !allow { if !allow {
return false return false
@@ -678,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
udpAddr := protoV4AddrPortToNetAddrPort(to) udpAddr := protoV4AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel { if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
Trace("remoteAllowList.Allow") "vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
} }
if !allow { if !allow {
@@ -698,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
udpAddr := protoV6AddrPortToNetAddrPort(to) udpAddr := protoV6AddrPortToNetAddrPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
if lh.l.Level >= logrus.TraceLevel { if lh.l.Enabled(context.Background(), logging.LevelTrace) {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
Trace("remoteAllowList.Allow") "vpnAddr", vpnAddr,
"udpAddr", udpAddr,
"allow", allow,
)
} }
if !allow { if !allow {
@@ -775,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
if v == cert.Version1 { if v == cert.Version1 {
if !addr.Is4() { if !addr.Is4() {
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
Error("Can't query lighthouse for v6 address using a v1 protocol") "queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue continue
} }
@@ -787,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v1Query, err = msg.Marshal() v1Query, err = msg.Marshal()
if err != nil { if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr). lh.l.Error("Failed to marshal lighthouse v1 query payload",
WithField("lighthouseAddr", lhVpnAddr). "error", err,
Error("Failed to marshal lighthouse v1 query payload") "queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue continue
} }
} }
@@ -804,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
v2Query, err = msg.Marshal() v2Query, err = msg.Marshal()
if err != nil { if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr). lh.l.Error("Failed to marshal lighthouse v2 query payload",
WithField("lighthouseAddr", lhVpnAddr). "error", err,
Error("Failed to marshal lighthouse v2 query payload") "queryVpnAddr", addr,
"lighthouseAddr", lhVpnAddr,
)
continue continue
} }
} }
@@ -815,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried++ queried++
} else { } else {
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) lh.l.Debug("unsupported protocol version",
"op", "query",
"queryVpnAddr", addr,
"version", v,
)
continue continue
} }
} }
@@ -907,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
if v == cert.Version1 { if v == cert.Version1 {
if v1Update == nil { if v1Update == nil {
if !lh.myVpnNetworks[0].Addr().Is4() { if !lh.myVpnNetworks[0].Addr().Is4() {
lh.l.WithField("lighthouseAddr", lhVpnAddr). lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
Warn("cannot update lighthouse using v1 protocol without an IPv4 address") "lighthouseAddr", lhVpnAddr,
)
continue continue
} }
var relays []uint32 var relays []uint32
@@ -932,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
v1Update, err = msg.Marshal() v1Update, err = msg.Marshal()
if err != nil { if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). lh.l.Error("Error while marshaling for lighthouse v1 update",
Error("Error while marshaling for lighthouse v1 update") "error", err,
"lighthouseAddr", lhVpnAddr,
)
continue continue
} }
} }
@@ -959,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
v2Update, err = msg.Marshal() v2Update, err = msg.Marshal()
if err != nil { if err != nil {
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). lh.l.Error("Error while marshaling for lighthouse v2 update",
Error("Error while marshaling for lighthouse v2 update") "error", err,
"lighthouseAddr", lhVpnAddr,
)
continue continue
} }
} }
@@ -969,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
updated++ updated++
} else { } else {
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) lh.l.Debug("unsupported protocol version",
"op", "update",
"version", v,
)
continue continue
} }
} }
@@ -983,7 +1024,7 @@ type LightHouseHandler struct {
out []byte out []byte
pb []byte pb []byte
meta *NebulaMeta meta *NebulaMeta
l *logrus.Logger l *slog.Logger
} }
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
@@ -1032,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
n := lhh.resetMeta() n := lhh.resetMeta()
err := n.Unmarshal(p) err := n.Unmarshal(p)
if err != nil { if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). lhh.l.Error("Failed to unmarshal lighthouse packet",
Error("Failed to unmarshal lighthouse packet") "error", err,
"vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return return
} }
if n.Details == nil { if n.Details == nil {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). lhh.l.Error("Invalid lighthouse update",
Error("Invalid lighthouse update") "vpnAddrs", fromVpnAddrs,
"udpAddr", rAddr,
)
return return
} }
@@ -1067,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries // Exit if we don't answer queries
if !lhh.lh.amLighthouse { if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debugln("I don't answer queries, but received from: ", addr) lhh.l.Debug("I don't answer queries, but received one", "from", addr)
} }
return return
} }
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil { if err != nil {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). lhh.l.Debug("Dropping malformed HostQuery",
Debugln("Dropping malformed HostQuery") "from", fromVpnAddrs,
"details", n.Details,
)
} }
return return
} }
if useVersion == cert.Version1 && queryVpnAddr.Is6() { if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway. // this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
Debugln("invalid vpn addr for v1 handleHostQuery") "vpnAddrs", fromVpnAddrs,
"queryVpnAddr", queryVpnAddr,
)
} }
return return
} }
@@ -1110,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
} }
if err != nil { if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") lhh.l.Error("Failed to marshal lighthouse host query reply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return return
} }
@@ -1138,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok { if ok {
whereToPunch = newDest whereToPunch = newDest
} else { } else {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common") lhh.l.Debug("unable to punch to host, no addresses in common",
"to", crt.Networks(),
)
} }
} }
} }
@@ -1165,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
} }
if err != nil { if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") lhh.l.Error("Failed to marshal lighthouse host was queried for",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return return
} }
@@ -1207,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
} }
} else { } else {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.WithField("version", v).Debug("unsupported protocol version") lhh.l.Debug("unsupported protocol version",
"op", "coalesceAnswers",
"version", v,
)
} }
} }
} }
@@ -1221,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil { if err != nil {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply") lhh.l.Error("dropping malformed HostQueryReply",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
} }
return return
} }
@@ -1247,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse { if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
} }
return return
} }
@@ -1271,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled //Simple check that the host sent this not someone else, if detailsVpnAddr is filled
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) { if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") lhh.l.Debug("Host sent invalid update",
"vpnAddrs", fromVpnAddrs,
"answer", detailsVpnAddr,
)
} }
return return
} }
@@ -1294,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
switch useVersion { switch useVersion {
case cert.Version1: case cert.Version1:
if !fromVpnAddrs[0].Is4() { if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
"vpnAddrs", fromVpnAddrs,
)
return return
} }
vpnAddrB := fromVpnAddrs[0].As4() vpnAddrB := fromVpnAddrs[0].As4()
@@ -1302,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
case cert.Version2: case cert.Version2:
// do nothing, we want to send a blank message // do nothing, we want to send a blank message
default: default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") lhh.l.Error("invalid protocol version", "useVersion", useVersion)
return return
} }
ln, err := n.MarshalTo(lhh.pb) ln, err := n.MarshalTo(lhh.pb)
if err != nil { if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") lhh.l.Error("Failed to marshal lighthouse host update ack",
"error", err,
"vpnAddrs", fromVpnAddrs,
)
return return
} }
@@ -1325,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil { if err != nil {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification") lhh.l.Debug("dropping invalid HostPunchNotification",
"details", n.Details,
"error", err,
)
} }
return return
} }
@@ -1343,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
lhh.lh.punchConn.WriteTo(empty, vpnPeer) lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}() }()
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) lhh.l.Debug("Punching",
"vpnPeer", vpnPeer,
"logVpnAddr", logVpnAddr,
)
} }
} }
@@ -1369,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
if lhh.lh.punchy.GetRespond() { if lhh.lh.punchy.GetRespond() {
go func() { go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay()) time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr) lhh.l.Debug("Sending a nebula test packet",
"vpnAddr", detailsVpnAddr,
)
} }
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine

View File

@@ -1,45 +0,0 @@
package nebula
import (
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
// configLogger applies the logging.* section of config c to the logrus
// logger l: level, output format ("text" or "json"), and timestamp
// behavior. Returns an error for an unknown level or format; note the
// level is applied before the format is validated, so a bad format can
// leave a new level in effect.
func configLogger(l *logrus.Logger, c *config.C) error {
	// set up our logging level
	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
	}
	l.SetLevel(logLevel)

	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
	// An explicit logging.timestamp_format turns on full timestamps;
	// otherwise fall back to RFC3339.
	timestampFormat := c.GetString("logging.timestamp_format", "")
	fullTimestamp := (timestampFormat != "")
	if timestampFormat == "" {
		timestampFormat = time.RFC3339
	}

	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
	switch logFormat {
	case "text":
		l.Formatter = &logrus.TextFormatter{
			TimestampFormat:  timestampFormat,
			FullTimestamp:    fullTimestamp,
			DisableTimestamp: disableTimestamp,
		}
	case "json":
		l.Formatter = &logrus.JSONFormatter{
			TimestampFormat:  timestampFormat,
			DisableTimestamp: disableTimestamp,
		}
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
	}

	return nil
}

233
logging/logger.go Normal file
View File

@@ -0,0 +1,233 @@
// Package logging wires the nebula runtime-reconfigurable slog handler used
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
// NewLogger, then call ApplyConfig at startup and from a config reload
// callback to push logging.level, logging.format, and
// logging.disable_timestamp changes onto the logger without rebuilding it.
package logging
import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"sync/atomic"
"time"
)
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
// here keeps the logging package from depending on config directly, which
// would cycle through the shared test helpers (test.NewLogger imports
// logging, and config's tests import test). *config.C satisfies this
// interface structurally with no adapter.
type Config interface {
	// GetString returns the string value at key; def is presumably the
	// fallback when the key is unset (mirrors *config.C) — confirm there.
	GetString(key, def string) string
	// GetBool returns the bool value at key; def is presumably the
	// fallback when the key is unset (mirrors *config.C) — confirm there.
	GetBool(key string, def bool) bool
}
// LevelTrace is a custom slog level below Debug, used when logging.level is
// "trace". slog has no builtin trace level; the value is one step below
// slog.LevelDebug in slog's 4-point spacing (Debug is -4, so trace is -8).
// ParseLevel and LevelName round-trip this value as "trace".
const LevelTrace = slog.Level(-8)
// NewLogger returns a *slog.Logger whose level, format, and timestamp
// emission can be reconfigured at runtime via ApplyConfig and the SSH debug
// commands. The default configuration is info-level text output so log
// calls made before ApplyConfig runs still produce output. Timestamps
// follow slog's default RFC3339Nano format; set logging.disable_timestamp
// in config to suppress them.
//
// ApplyConfig and the SSH commands discover the reconfig surface via
// structural type-assertion on l.Handler(), so replacement implementations
// (tests, platform-specific sinks) need only implement the subset of
// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)}
// they care about. Callers that pass a plain *slog.Logger without these
// methods get a silent no-op; reconfiguration is always opt-in.
func NewLogger(w io.Writer) *slog.Logger {
	h := NewHandler(w)
	return slog.New(h)
}
// NewHandler builds the *Handler that NewLogger wraps. Exported for
// platform-specific sinks (notably cmd/nebula-service/logs_windows.go)
// that want to wrap the handler with extra behavior, such as tagging each
// record with its Event Log severity, while still benefiting from all the
// level / format / timestamp / WithAttrs machinery implemented here.
func NewHandler(w io.Writer) *Handler {
	// Both inner handlers share one LevelVar so SetLevel reaches every
	// derived logger; they start at info so pre-ApplyConfig logs appear.
	shared := &handlerRoot{}
	shared.level.Set(slog.LevelInfo)

	opts := &slog.HandlerOptions{Level: &shared.level}
	h := &Handler{root: shared}
	h.text = slog.NewTextHandler(w, opts)
	h.json = slog.NewJSONHandler(w, opts)
	return h
}
// handlerRoot carries the reconfiguration state shared by every logger
// derived from a NewHandler call. All fields are consulted on the log
// path and updated lock-free.
type handlerRoot struct {
	// level gates Enabled and the inner handlers; shared so SetLevel
	// propagates to every derived logger.
	level slog.LevelVar
	// disableTimestamp, when true, makes Handle zero r.Time so the inner
	// handlers omit the time attribute.
	disableTimestamp atomic.Bool
	// jsonMode picks which of the pre-derived inner handlers Handler.Handle
	// dispatches to. Flipping it propagates instantly to every derived logger
	// without rebuilding or chain-replaying anything.
	jsonMode atomic.Bool
}
// Handler is the slog.Handler returned by NewHandler. It holds two
// pre-derived slog handlers -- one text, one json -- both built from the
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
// effect immediately across the whole process without having to rebuild
// any derived loggers.
type Handler struct {
	// root is shared by every Handler derived via WithAttrs/WithGroup.
	root *handlerRoot
	// text and json carry identical WithAttrs/WithGroup state; Handle
	// selects between them per record.
	text slog.Handler
	json slog.Handler
}
// Enabled reports whether records at level l pass the shared level gate.
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
	return l >= h.root.level.Level()
}
// Handle forwards r to the text or json inner handler, depending on the
// currently selected format, zeroing the record time first when timestamps
// are disabled (slog's builtin handlers skip a zero time).
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
	if h.root.disableTimestamp.Load() {
		r.Time = time.Time{}
	}
	target := h.text
	if h.root.jsonMode.Load() {
		target = h.json
	}
	return target.Handle(ctx, r)
}
// WithAttrs derives a new Handler carrying attrs on both inner handlers
// while sharing the same reconfiguration root.
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) > 0 {
		return &Handler{
			root: h.root,
			text: h.text.WithAttrs(attrs),
			json: h.json.WithAttrs(attrs),
		}
	}
	// Nothing to add; reuse the receiver.
	return h
}
// WithGroup derives a new Handler that opens group name on both inner
// handlers while sharing the same reconfiguration root.
func (h *Handler) WithGroup(name string) slog.Handler {
	if name != "" {
		return &Handler{
			root: h.root,
			text: h.text.WithGroup(name),
			json: h.json.WithGroup(name),
		}
	}
	// Empty group names are a no-op per the slog.Handler contract.
	return h
}
// SetLevel updates the effective log level. Every logger derived from the
// same NewHandler call observes the change through the shared LevelVar.
func (h *Handler) SetLevel(level slog.Level) {
	h.root.level.Set(level)
}
// GetLevel reports the log level currently in effect.
func (h *Handler) GetLevel() slog.Level {
	return h.root.level.Level()
}
// SetFormat flips the output format atomically. Valid formats are "text"
// and "json". Every derived logger sees the new format on its next Handle
// call; no rebuild or registration is required.
func (h *Handler) SetFormat(format string) error {
	var useJSON bool
	switch format {
	case "json":
		useJSON = true
	case "text":
		useJSON = false
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
	}
	h.root.jsonMode.Store(useJSON)
	return nil
}
// GetFormat reports the currently selected format name.
func (h *Handler) GetFormat() string {
	format := "text"
	if h.root.jsonMode.Load() {
		format = "json"
	}
	return format
}
// SetDisableTimestamp toggles whether Handle zeroes r.Time before
// dispatching (slog's builtin text/json handlers skip emitting the time
// attribute on a zero time).
func (h *Handler) SetDisableTimestamp(v bool) {
	h.root.disableTimestamp.Store(v)
}
// ApplyConfig reads logging.level, logging.format, and (optionally)
// logging.disable_timestamp from c and applies them to l. The reconfig
// surface is discovered via structural type-assertion on l.Handler(), so
// foreign handlers silently opt out of whichever capabilities they do not
// implement.
//
// nebula.Main does NOT call this function on your behalf; callers that want
// config-driven log level / format / timestamp updates invoke it at
// startup and register it as a reload callback themselves. This keeps the
// library from mutating an embedder's logger without their say-so.
func ApplyConfig(l *slog.Logger, c Config) error {
	h := l.Handler()

	// A bad level name is an error even if the handler cannot set levels.
	lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
	if err != nil {
		return err
	}
	if setter, ok := h.(interface{ SetLevel(slog.Level) }); ok {
		setter.SetLevel(lvl)
	}

	format := strings.ToLower(c.GetString("logging.format", "text"))
	if setter, ok := h.(interface{ SetFormat(string) error }); ok {
		if ferr := setter.SetFormat(format); ferr != nil {
			return ferr
		}
	}

	if setter, ok := h.(interface{ SetDisableTimestamp(bool) }); ok {
		setter.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
	}
	return nil
}
// ParseLevel converts a config-string level name ("trace", "debug", "info",
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and
// "panic" are accepted for backwards compatibility with pre-slog configs
// and both map to slog.LevelError.
func ParseLevel(s string) (slog.Level, error) {
	switch s {
	case "trace":
		return LevelTrace, nil
	case "debug":
		return slog.LevelDebug, nil
	case "info":
		return slog.LevelInfo, nil
	case "warn", "warning":
		return slog.LevelWarn, nil
	case "error", "fatal", "panic":
		return slog.LevelError, nil
	}
	return 0, fmt.Errorf("not a valid logging level: %q", s)
}
// LevelName returns a human-readable name for a slog.Level matching the
// strings accepted by ParseLevel. Levels between the named points collapse
// to the next name up; anything above warn reads as "error".
func LevelName(l slog.Level) string {
	if l <= LevelTrace {
		return "trace"
	}
	if l <= slog.LevelDebug {
		return "debug"
	}
	if l <= slog.LevelInfo {
		return "info"
	}
	if l <= slog.LevelWarn {
		return "warn"
	}
	return "error"
}

View File

@@ -0,0 +1,90 @@
package logging
import (
"context"
"io"
"log/slog"
"testing"
)
// BenchmarkLogger_* compare the handler returned by NewLogger against a
// stock slog text handler. The key thing we care about is the per-log
// cost on a logger that has been derived via .With(), because that is the
// shape subsystems store on their structs (HostInfo.logger(),
// lh.l.With("subsystem", ...), etc.) and call from hot paths.
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
	l := slog.New(slog.DiscardHandler)
	b.ReportAllocs()
	// b.Loop (Go 1.24+) excludes setup from the timed region and prevents
	// dead-code elimination of the body; i varies so the logged value is
	// not a compile-time constant.
	for i := 0; b.Loop(); i++ {
		l.Info("hello", "i", i)
	}
}
// BenchmarkLogger_Nebula_RootInfo measures a bare (underived) nebula
// logger, for comparison against the stock slog handler above.
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
	l := NewLogger(io.Discard)
	b.ReportAllocs()
	// b.Loop (Go 1.24+) excludes setup from the timed region and prevents
	// dead-code elimination of the body.
	for i := 0; b.Loop(); i++ {
		l.Info("hello", "i", i)
	}
}
// BenchmarkLogger_Stock_DerivedInfo measures per-log cost on a stock slog
// logger derived via .With(), the shape subsystems keep on their structs.
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
	l := slog.New(slog.DiscardHandler).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	// b.Loop (Go 1.24+) excludes setup from the timed region and prevents
	// dead-code elimination of the body.
	for i := 0; b.Loop(); i++ {
		l.Info("hello", "i", i)
	}
}
// BenchmarkLogger_Nebula_DerivedInfo measures per-log cost on a nebula
// logger derived via .With(), the shape subsystems keep on their structs.
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
	l := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	b.ReportAllocs()
	// b.Loop (Go 1.24+) excludes setup from the timed region and prevents
	// dead-code elimination of the body.
	for i := 0; b.Loop(); i++ {
		l.Info("hello", "i", i)
	}
}
// Gated-off-path benchmarks: mimic the typical hot-path shape
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below
// the active level. This is the dominant pattern in inside.go/outside.go and
// what we pay on every packet.
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
	l := slog.New(slog.DiscardHandler).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	ctx := context.Background()
	b.ReportAllocs()
	// b.Loop (Go 1.24+) excludes setup from the timed region and prevents
	// dead-code elimination of the gate check.
	for i := 0; b.Loop(); i++ {
		if l.Enabled(ctx, slog.LevelDebug) {
			l.Debug("hello", "i", i)
		}
	}
}
// BenchmarkLogger_Nebula_DerivedEnabledGateMiss: same gate-miss shape as
// the stock variant above, but on a derived nebula logger.
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
	l := NewLogger(io.Discard).With(
		"subsystem", "bench",
		"localIndex", 1234,
	)
	ctx := context.Background()
	b.ReportAllocs()
	// b.Loop (Go 1.24+) excludes setup from the timed region and prevents
	// dead-code elimination of the gate check.
	for i := 0; b.Loop(); i++ {
		if l.Enabled(ctx, slog.LevelDebug) {
			l.Debug("hello", "i", i)
		}
	}
}

39
main.go
View File

@@ -3,13 +3,13 @@ package nebula
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"runtime/debug" "runtime/debug"
"strings" "strings"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
@@ -20,7 +20,7 @@ import (
type m = map[string]any type m = map[string]any
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
defer func() { defer func() {
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
buildVersion = moduleVersion() buildVersion = moduleVersion()
} }
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
// Print the config if in test, the exit comes later // Print the config if in test, the exit comes later
if configTest { if configTest {
b, err := yaml.Marshal(c.Settings) b, err := yaml.Marshal(c.Settings)
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
// Print the final config // Print the final config
l.Println(string(b)) l.Info(string(b))
} }
err := configLogger(l, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
}
c.RegisterReloadCallback(func(c *config.C) {
err := configLogger(l, c)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
}
})
pki, err := NewPKIFromConfig(l, c) pki, err := NewPKIFromConfig(l, c)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
} }
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
} }
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c) sshStart, err = configSSH(l, ssh, c)
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
sshStart = nil sshStart = nil
} }
} }
@@ -99,7 +82,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
routines = 1 routines = 1
} }
if routines > 1 { if routines > 1 {
l.WithField("routines", routines).Info("Using multiple routines") l.Info("Using multiple routines", "routines", routines)
} }
} else { } else {
// deprecated and undocumented // deprecated and undocumented
@@ -107,7 +90,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
udpQueues := c.GetInt("listen.routines", 1) udpQueues := c.GetInt("listen.routines", 1)
routines = max(tunQueues, udpQueues) routines = max(tunQueues, udpQueues)
if routines != 1 { if routines != 1 {
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead") l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
} }
} }
@@ -120,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
conntrackCacheTimeout = 1 * time.Second conntrackCacheTimeout = 1 * time.Second
} }
if conntrackCacheTimeout > 0 { if conntrackCacheTimeout > 0 {
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
} }
var tun overlay.Device var tun overlay.Device
@@ -166,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
@@ -217,7 +200,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to start DNS responder") l.Warn("Failed to start DNS responder", "error", err)
} }
ifConfig := &InterfaceConfig{ ifConfig := &InterfaceConfig{

View File

@@ -1,15 +1,16 @@
package nebula package nebula
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"log/slog"
"net/netip" "net/netip"
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 { if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) f.l.Info("Error while parsing inbound packet",
"from", via,
"error", err,
"packet", packet,
)
} }
return return
} }
@@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed { if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") f.l.Debug("Refusing to process double encrypted packet", "from", via)
} }
return return
} }
@@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if !ok { if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen. // its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
"vpnAddrs", hostinfo.vpnAddrs,
"remoteIndex", h.RemoteIndex,
)
return return
} }
@@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Find the target HostInfo relay object // Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") hostinfo.logger(f.l).Info("Failed to find target host info by ip",
"relayTo", relay.PeerAddr,
"error", err,
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
)
return return
} }
@@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
} }
} else { } else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") hostinfo.logger(f.l).Info("Unexpected target relay state",
"relayTo", relay.PeerAddr,
"relayFrom", hostinfo.vpnAddrs[0],
"targetRelayState", targetRelay.State,
)
return return
} }
} }
@@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
WithField("packet", packet). "error", err,
Error("Failed to decrypt lighthouse packet") "from", via,
"packet", packet,
)
return return
} }
@@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).Error("Failed to decrypt test packet",
WithField("packet", packet). "error", err,
Error("Failed to decrypt test packet") "from", via,
"packet", packet,
)
return return
} }
@@ -192,14 +212,15 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
} }
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
WithField("packet", packet). "error", err,
Error("Failed to decrypt CloseTunnel packet") "from", via,
"packet", packet,
)
return return
} }
hostinfo.logger(f.l).WithField("from", via). hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo) f.closeTunnel(hostinfo)
return return
@@ -211,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
WithField("packet", packet). "error", err,
Error("Failed to decrypt Control packet") "from", via,
"packet", packet,
)
return return
} }
@@ -221,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
}
return return
} }
@@ -247,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
}
return return
} }
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) "suppressSeconds", RoamingSuppressSeconds,
"udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
} }
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
Info("Host roamed to new udp ip/port.") "udpAddr", hostinfo.remote,
"newAddr", via.UdpAddr,
)
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(via.UdpAddr) hostinfo.SetRemote(via.UdpAddr)
@@ -491,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
} }
if !hostinfo.ConnectionState.window.Update(f.l, mc) { if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", h). if f.l.Enabled(context.Background(), slog.LevelDebug) {
Debugln("dropping out of window packet") hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
}
return nil, errors.New("out of window packet") return nil, errors.New("out of window packet")
} }
@@ -504,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
return false return false
} }
err = newPacket(out, true, fwPacket) err = newPacket(out, true, fwPacket)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out). hostinfo.logger(f.l).Warn("Error while validating inbound packet",
Warnf("Error while validating inbound packet") "error", err,
"packet", out,
)
return false return false
} }
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket). if f.l.Enabled(context.Background(), slog.LevelDebug) {
Debugln("dropping out of window packet") hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
}
return false return false
} }
@@ -526,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in // This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket). hostinfo.logger(f.l).Debug("dropping inbound packet",
WithField("reason", dropReason). "fwPacket", fwPacket,
Debugln("dropping inbound packet") "reason", dropReason,
)
} }
return false return false
} }
@@ -537,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out) _, err = f.readers[q].Write(out)
if err != nil { if err != nil {
f.l.WithError(err).Error("Failed to write to tun") f.l.Error("Failed to write to tun", "error", err)
} }
return true return true
} }
@@ -553,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
_ = f.outside.WriteTo(b, endpoint) _ = f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("index", index). f.l.Debug("Recv error sent",
WithField("udpAddr", endpoint). "index", index,
Debug("Recv error sent") "udpAddr", endpoint,
)
} }
} }
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) { if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
f.l.WithField("index", h.RemoteIndex). f.l.Debug("Recv error received, ignoring",
WithField("udpAddr", addr). "index", h.RemoteIndex,
Debug("Recv error received, ignoring") "udpAddr", addr,
)
return return
} }
if f.l.Level >= logrus.DebugLevel { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.WithField("index", h.RemoteIndex). f.l.Debug("Recv error received",
WithField("udpAddr", addr). "index", h.RemoteIndex,
Debug("Recv error received") "udpAddr", addr,
)
} }
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if hostinfo == nil { if hostinfo == nil {
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
return return
} }
if hostinfo.remote.IsValid() && hostinfo.remote != addr { if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) f.l.Info("Someone spoofing recv_errors?",
"addr", addr,
"hostinfoRemote", hostinfo.remote,
)
return return
} }

View File

@@ -1,4 +1,6 @@
package test // Package overlaytest provides fakes of overlay.Device for tests that do
// not want to touch a real tun device or route table.
package overlaytest
import ( import (
"errors" "errors"
@@ -8,6 +10,9 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
// NoopTun is an overlay.Device that silently discards every read and write.
// Useful in tests that need to construct a nebula Interface but do not
// exercise the datapath.
type NoopTun struct{} type NoopTun struct{}
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {

View File

@@ -2,6 +2,7 @@ package overlay
import ( import (
"fmt" "fmt"
"log/slog"
"math" "math"
"net" "net"
"net/netip" "net/netip"
@@ -9,7 +10,6 @@ import (
"strconv" "strconv"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -48,11 +48,14 @@ func (r Route) String() string {
return s return s
} }
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
routeTree := new(bart.Table[routing.Gateways]) routeTree := new(bart.Table[routing.Gateways])
for _, r := range routes { for _, r := range routes {
if !allowMTU && r.MTU > 0 { if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) l.Warn("route MTU is not supported on this platform",
"goos", runtime.GOOS,
"route", r,
)
} }
gateways := r.Via gateways := r.Via

View File

@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err) require.NoError(t, err)
ip, err := netip.ParseAddr("1.0.0.2") ip, err := netip.ParseAddr("1.0.0.2")
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 3) assert.Len(t, routes, 3)
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
require.NoError(t, err) require.NoError(t, err)
ip, err := netip.ParseAddr("192.168.86.1") ip, err := netip.ParseAddr("192.168.86.1")

View File

@@ -2,10 +2,10 @@ package overlay
import ( import (
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -22,9 +22,9 @@ func (e *NameError) Error() string {
} }
// TODO: We may be able to remove routines // TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch { switch {
case c.GetBool("tun.disabled", false): case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
} }
func NewFdDeviceFromConfig(fd *int) DeviceFactory { func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks) return newTunFromFd(c, l, *fd, vpnNetworks)
} }
} }

View File

@@ -6,12 +6,12 @@ package overlay
import ( import (
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -23,10 +23,10 @@ type tun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *slog.Logger
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode. // Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil return t, nil
} }
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android") return nil, fmt.Errorf("newTun not supported in Android")
} }

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
@@ -14,7 +15,6 @@ import (
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -30,7 +30,7 @@ type tun struct {
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *logrus.Logger l *slog.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte out []byte
@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
Lifetime addrLifetime Lifetime addrLifetime
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "") name := c.GetString("tun.dev", "")
ifIndex := -1 ifIndex := -1
if name != "" && name != "utun" { if name != "" && name != "utun" {
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }
@@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error {
err := addRoute(r.Cidr, t.linkAddr) err := addRoute(r.Cidr, t.linkAddr)
if err != nil { if err != nil {
if errors.Is(err, unix.EEXIST) { if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr). t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr)
Warnf("unable to add unsafe_route, identical route already exists")
} else { } else {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
@@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error {
} }
} }
} else { } else {
t.l.WithField("route", r).Info("Added route") t.l.Info("Added route", "route", r)
} }
} }
@@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.linkAddr) err := delRoute(r.Cidr, t.linkAddr)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.Error("Failed to remove route", "error", err, "route", r)
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.Info("Removed route", "route", r)
} }
} }
return nil return nil

View File

@@ -1,13 +1,14 @@
package overlay package overlay
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"strings" "strings"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -19,10 +20,10 @@ type disabledTun struct {
// Track these metrics since we don't have the tun device to do it for us // Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter tx metrics.Counter
rx metrics.Counter rx metrics.Counter
l *logrus.Logger l *slog.Logger
} }
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
tun := &disabledTun{ tun := &disabledTun{
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen), read: make(chan []byte, queueLen),
@@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
} }
t.tx.Inc(1) t.tx.Inc(1)
if t.l.Level >= logrus.DebugLevel { if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") t.l.Debug("Write payload", "raw", prettyPacket(r))
} }
return copy(b, r), nil return copy(b, r), nil
@@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
select { select {
case t.read <- out: case t.read <- out:
default: default:
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") t.l.Debug("tun_disabled: dropped ICMP Echo Reply response")
} }
return true return true
@@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
// Check for ICMP Echo Request before spending time doing the full parsing // Check for ICMP Echo Request before spending time doing the full parsing
if t.handleICMPEchoRequest(b) { if t.handleICMPEchoRequest(b) {
if t.l.Level >= logrus.DebugLevel { if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b))
} }
} else if t.l.Level >= logrus.DebugLevel { } else if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b))
} }
return len(b), nil return len(b), nil
} }

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
@@ -17,8 +18,9 @@ import (
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
@@ -93,7 +95,7 @@ type tun struct {
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *logrus.Logger l *slog.Logger
fd int fd int
shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls
@@ -243,7 +245,7 @@ func (t *tun) Close() error {
if t.fd >= 0 { if t.fd >= 0 {
if err := unix.Close(t.fd); err != nil { if err := unix.Close(t.fd); err != nil {
t.l.WithError(err).Error("Error closing device") t.l.Error("Error closing device", "error", err)
} }
t.fd = -1 t.fd = -1
} }
@@ -264,7 +266,7 @@ func (t *tun) Close() error {
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
} }
if err != nil { if err != nil {
t.l.WithError(err).Error("Error destroying tunnel") t.l.Error("Error destroying tunnel", "error", err)
} }
}() }()
@@ -277,11 +279,11 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var fd int var fd int
var err error var err error
@@ -584,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr return retErr
} }
} else { } else {
t.l.WithField("route", r).Info("Added route") t.l.Info("Added route", "route", r)
} }
} }
@@ -599,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.linkAddr) err := delRoute(r.Cidr, t.linkAddr)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.Error("Failed to remove route", "error", err, "route", r)
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.Info("Removed route", "route", r)
} }
} }
return nil return nil

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"sync" "sync"
@@ -14,7 +15,6 @@ import (
"syscall" "syscall"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -25,14 +25,14 @@ type tun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *slog.Logger
} }
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS") return nil, fmt.Errorf("newTun not supported in iOS")
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun") file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{ t := &tun{
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,

View File

@@ -7,6 +7,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@@ -17,7 +18,6 @@ import (
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -213,7 +213,7 @@ type tun struct {
routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystem map[netip.Prefix]routing.Gateways
routesFromSystemLock sync.Mutex routesFromSystemLock sync.Mutex
l *logrus.Logger l *slog.Logger
} }
func (t *tun) Networks() []netip.Prefix { func (t *tun) Networks() []netip.Prefix {
@@ -238,7 +238,7 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -249,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil return t, nil
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker) // If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -299,7 +299,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
} }
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
tfd, err := newTunFd(fd) tfd, err := newTunFd(fd)
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
@@ -378,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error {
if !initial { if !initial {
if oldMaxMTU != newMaxMTU { if oldMaxMTU != newMaxMTU {
t.setMTU() t.setMTU()
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU)
} }
if oldDefaultMTU != newDefaultMTU { if oldDefaultMTU != newDefaultMTU {
for i := range t.vpnNetworks { for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i]) err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil { if err != nil {
t.l.Warn(err) t.l.Warn(err.Error())
} else { } else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU)
} }
} }
} }
@@ -492,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error {
} }
err = netlink.AddrDel(link, &al[i]) err = netlink.AddrDel(link, &al[i])
if err != nil { if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list") t.l.Error("failed to remove address from tun address list", "error", err)
} else { } else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") t.l.Info("removed address not listed in cert(s)", "removed", al[i].String())
} }
} }
@@ -538,12 +538,12 @@ func (t *tun) Activate() error {
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss // If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length") t.l.Error("Failed to set tun tx queue length", "error", err)
} }
const modeNone = 1 const modeNone = 1
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil { if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
t.l.WithError(err).Warn("Failed to disable link local address generation") t.l.Warn("Failed to disable link local address generation", "error", err)
} }
if err = t.addIPs(link); err != nil { if err = t.addIPs(link); err != nil {
@@ -582,7 +582,7 @@ func (t *tun) setMTU() {
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well // This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu") t.l.Error("Failed to set tun mtu", "error", err)
} }
} }
@@ -605,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
} }
err := netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -613,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
if err == nil { if err == nil {
break break
} else { } else {
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") t.l.Warn("Failed to set default route MTU, retrying",
"error", err,
"cidr", cidr,
"mtu", t.DefaultMTU,
)
} }
} }
if err != nil { if err != nil {
@@ -658,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr return retErr
} }
} else { } else {
t.l.WithField("route", r).Info("Added route") t.l.Info("Added route", "route", r)
} }
} }
@@ -690,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) {
err := netlink.RouteDel(&nr) err := netlink.RouteDel(&nr)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.Error("Failed to remove route", "error", err, "route", r)
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.Info("Removed route", "route", r)
} }
} }
} }
@@ -721,11 +725,11 @@ func (t *tun) watchRoutes() {
netlinkOptions := netlink.RouteSubscribeOptions{ netlinkOptions := netlink.RouteSubscribeOptions{
ReceiveBufferSize: t.useSystemRoutesBufferSize, ReceiveBufferSize: t.useSystemRoutesBufferSize,
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0, ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") }, ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) },
} }
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil { if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
t.l.WithError(err).Errorf("failed to subscribe to system route changes") t.l.Error("failed to subscribe to system route changes", "error", err)
return return
} }
@@ -767,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
link, err := netlink.LinkByName(t.Device) link, err := netlink.LinkByName(t.Device)
if err != nil { if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device)
return gateways return gateways
} }
@@ -779,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
gateways = append(gateways, routing.NewGateway(gwAddr, 1)) gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else { } else {
// Gateway isn't in our overlay network, ignore // Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
} }
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
} }
} }
@@ -795,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else { } else {
// Gateway isn't in our overlay network, ignore // Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
} }
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
} }
} }
} }
@@ -830,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route) gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 { if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required. // No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways") t.l.Debug("Ignoring route update, no gateways", "route", r)
return return
} }
if r.Dst == nil { if r.Dst == nil {
t.l.WithField("route", r).Debug("Ignoring route update, no destination address") t.l.Debug("Ignoring route update, no destination address", "route", r)
return return
} }
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
if !ok { if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") t.l.Debug("Ignoring route update, invalid destination address", "route", r)
return return
} }
@@ -852,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
t.routesFromSystemLock.Lock() t.routesFromSystemLock.Lock()
if r.Type == unix.RTM_NEWROUTE { if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") t.l.Info("Adding route", "destination", dst, "via", gateways)
t.routesFromSystem[dst] = gateways t.routesFromSystem[dst] = gateways
newTree.Insert(dst, gateways) newTree.Insert(dst, gateways)
} else { } else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") t.l.Info("Removing route", "destination", dst, "via", gateways)
delete(t.routesFromSystem, dst) delete(t.routesFromSystem, dst)
newTree.Delete(dst) newTree.Delete(dst)
} }
@@ -888,18 +892,18 @@ func (t *tun) Close() error {
} }
err := t.readers[i].Close() err := t.readers[i].Close()
if err != nil { if err != nil {
t.l.WithField("reader", i).WithError(err).Error("error closing tun reader") t.l.Error("error closing tun reader", "reader", i, "error", err)
} else { } else {
t.l.WithField("reader", i).Info("closed tun reader") t.l.Info("closed tun reader", "reader", i)
} }
} }
//this is t.readers[0] too //this is t.readers[0] too
err := t.tunFile.Close() err := t.tunFile.Close()
if err != nil { if err != nil {
t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader") t.l.Error("error closing tun reader", "reader", 0, "error", err)
} else { } else {
t.l.WithField("reader", 0).Info("closed tun reader") t.l.Info("closed tun reader", "reader", 0)
} }
return err return err
} }

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"regexp" "regexp"
@@ -15,7 +16,6 @@ import (
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -63,18 +63,18 @@ type tun struct {
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *slog.Logger
f *os.File f *os.File
fd int fd int
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device // Try to open tun device
var err error var err error
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
@@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
err = unix.SetNonblock(fd, true) err = unix.SetNonblock(fd, true)
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking") l.Warn("Failed to set the tun device as nonblocking", "error", err)
} }
t := &tun{ t := &tun{
@@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr return retErr
} }
} else { } else {
t.l.WithField("route", r).Info("Added route") t.l.Info("Added route", "route", r)
} }
} }
@@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.vpnNetworks) err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.Error("Failed to remove route", "error", err, "route", r)
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.Info("Removed route", "route", r)
} }
} }
return nil return nil

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"regexp" "regexp"
@@ -15,7 +16,6 @@ import (
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -54,7 +54,7 @@ type tun struct {
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *slog.Logger
f *os.File f *os.File
fd int fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -63,11 +63,11 @@ type tun struct {
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd") return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device // Try to open tun device
var err error var err error
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
@@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
err = unix.SetNonblock(fd, true) err = unix.SetNonblock(fd, true)
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking") l.Warn("Failed to set the tun device as nonblocking", "error", err)
} }
t := &tun{ t := &tun{
@@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error {
return retErr return retErr
} }
} else { } else {
t.l.WithField("route", r).Info("Added route") t.l.Info("Added route", "route", r)
} }
} }
@@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error {
err := delRoute(r.Cidr, t.vpnNetworks) err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.Error("Failed to remove route", "error", err, "route", r)
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.Info("Removed route", "route", r)
} }
} }
return nil return nil

View File

@@ -4,14 +4,15 @@
package overlay package overlay
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -21,14 +22,14 @@ type TestTun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes []Route Routes []Route
routeTree *bart.Table[routing.Gateways] routeTree *bart.Table[routing.Gateways]
l *logrus.Logger l *slog.Logger
closed atomic.Bool closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}, nil }, nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported") return nil, fmt.Errorf("newTunFromFd not supported")
} }
@@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) {
return return
} }
if t.l.Level >= logrus.DebugLevel { if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
} }
t.rxPackets <- packet t.rxPackets <- packet
} }

View File

@@ -7,6 +7,7 @@ import (
"crypto" "crypto"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
@@ -16,7 +17,6 @@ import (
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -33,16 +33,16 @@ type winTun struct {
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *slog.Logger
tun *wintun.NativeTun tun *wintun.NativeTun
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists() err := checkWinTunExists()
if err != nil { if err != nil {
return nil, fmt.Errorf("can not load the wintun driver: %w", err) return nil, fmt.Errorf("can not load the wintun driver: %w", err)
@@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
if err != nil { if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue. // Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying") l.Debug("Failed to create wintun device, retrying", "error", err)
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil { if err != nil {
return nil, &NameError{ return nil, &NameError{
@@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
return retErr return retErr
} }
} else { } else {
t.l.WithField("route", r).Info("Added route") t.l.Info("Added route", "route", r)
} }
if !foundDefault4 { if !foundDefault4 {
@@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error {
// See comment on luid.AddRoute // See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.Error("Failed to remove route", "error", err, "route", r)
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.Info("Removed route", "route", r)
} }
} }
return nil return nil

View File

@@ -2,14 +2,14 @@ package overlay
import ( import (
"io" "io"
"log/slog"
"net/netip" "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(vpnNetworks) return NewUserDevice(vpnNetworks)
} }

18
pki.go
View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@@ -15,7 +16,6 @@ import (
"time" "time"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
@@ -24,7 +24,7 @@ import (
type PKI struct { type PKI struct {
cs atomic.Pointer[CertState] cs atomic.Pointer[CertState]
caPool atomic.Pointer[cert.CAPool] caPool atomic.Pointer[cert.CAPool]
l *logrus.Logger l *slog.Logger
} }
type CertState struct { type CertState struct {
@@ -46,7 +46,7 @@ type CertState struct {
myVpnBroadcastAddrsTable *bart.Lite myVpnBroadcastAddrsTable *bart.Lite
} }
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) {
pki := &PKI{l: l} pki := &PKI{l: l}
err := pki.reload(c, true) err := pki.reload(c, true)
if err != nil { if err != nil {
@@ -182,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState) p.cs.Store(newState)
if initial { if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") p.l.Debug("Client nebula certificate(s)", "cert", newState)
} else { } else {
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") p.l.Info("Client certificate(s) refreshed from disk", "cert", newState)
} }
return nil return nil
} }
@@ -196,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
} }
p.caPool.Store(caPool) p.caPool.Store(caPool)
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints())
return nil return nil
} }
@@ -487,7 +487,7 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
return c, b, nil return c, b, nil
} }
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) {
caPathOrPEM := c.GetString("pki.ca", "") caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" { if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided") return nil, errors.New("no pki.ca path or PEM data provided")
@@ -512,7 +512,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
for _, crt := range caPool.CAs { for _, crt := range caPool.CAs {
if crt.Certificate.Expired(time.Now()) { if crt.Certificate.Expired(time.Now()) {
expired++ expired++
l.WithField("cert", crt).Warn("expired certificate present in CA pool") l.Warn("expired certificate present in CA pool", "cert", crt)
} }
} }
@@ -530,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
caPool.BlocklistFingerprint(fp) caPool.BlocklistFingerprint(fp)
} }
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates") l.Info("Blocklisted certificates", "fingerprintCount", len(bl))
} }
return caPool, nil return caPool, nil

View File

@@ -41,7 +41,7 @@ func BenchmarkReloadConfigWithCAs(b *testing.B) {
c := config.NewC(l) c := config.NewC(l)
require.NoError(b, c.Load(dir)) require.NoError(b, c.Load(dir))
_, err := NewPKIFromConfig(l, c) _, err := NewPKIFromConfig(test.NewLogger(), c)
require.NoError(b, err) require.NoError(b, err)
b.ReportAllocs() b.ReportAllocs()

View File

@@ -1,10 +1,10 @@
package nebula package nebula
import ( import (
"log/slog"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
@@ -14,10 +14,10 @@ type Punchy struct {
delay atomic.Int64 delay atomic.Int64
respondDelay atomic.Int64 respondDelay atomic.Int64
punchEverything atomic.Bool punchEverything atomic.Bool
l *logrus.Logger l *slog.Logger
} }
func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
p := &Punchy{l: l} p := &Punchy{l: l}
p.reload(c, true) p.reload(c, true)
@@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
p.respond.Store(yes) p.respond.Store(yes)
if !initial { if !initial {
p.l.Infof("punchy.respond changed to %v", p.GetRespond()) p.l.Info("punchy.respond changed", "respond", p.GetRespond())
} }
} }
@@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) {
if initial || c.HasChanged("punchy.delay") { if initial || c.HasChanged("punchy.delay") {
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
if !initial { if !initial {
p.l.Infof("punchy.delay changed to %s", p.GetDelay()) p.l.Info("punchy.delay changed", "delay", p.GetDelay())
} }
} }
if initial || c.HasChanged("punchy.target_all_remotes") { if initial || c.HasChanged("punchy.target_all_remotes") {
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
if !initial { if !initial {
p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
} }
} }
if initial || c.HasChanged("punchy.respond_delay") { if initial || c.HasChanged("punchy.respond_delay") {
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
if !initial { if !initial {
p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay()) p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
} }
} }
} }

View File

@@ -1,6 +1,8 @@
package nebula package nebula
import ( import (
"context"
"log/slog"
"testing" "testing"
"time" "time"
@@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
// Test defaults // Test defaults
p := NewPunchyFromConfig(l, c) p := NewPunchyFromConfig(test.NewLogger(), c)
assert.False(t, p.GetPunch()) assert.False(t, p.GetPunch())
assert.False(t, p.GetRespond()) assert.False(t, p.GetRespond())
assert.Equal(t, time.Second, p.GetDelay()) assert.Equal(t, time.Second, p.GetDelay())
@@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) {
// punchy deprecation // punchy deprecation
c.Settings["punchy"] = true c.Settings["punchy"] = true
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch()) assert.True(t, p.GetPunch())
// punchy.punch // punchy.punch
c.Settings["punchy"] = map[string]any{"punch": true} c.Settings["punchy"] = map[string]any{"punch": true}
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetPunch()) assert.True(t, p.GetPunch())
// punch_back deprecation // punch_back deprecation
c.Settings["punch_back"] = true c.Settings["punch_back"] = true
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond()) assert.True(t, p.GetRespond())
// punchy.respond // punchy.respond
c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punch_back"] = false c.Settings["punch_back"] = false
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(test.NewLogger(), c)
assert.True(t, p.GetRespond()) assert.True(t, p.GetRespond())
// punchy.delay // punchy.delay
c.Settings["punchy"] = map[string]any{"delay": "1m"} c.Settings["punchy"] = map[string]any{"delay": "1m"}
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetDelay()) assert.Equal(t, time.Minute, p.GetDelay())
// punchy.respond_delay // punchy.respond_delay
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, time.Minute, p.GetRespondDelay()) assert.Equal(t, time.Minute, p.GetRespondDelay())
} }
@@ -62,7 +64,7 @@ punchy:
delay: 1m delay: 1m
respond: false respond: false
`)) `))
p := NewPunchyFromConfig(l, c) p := NewPunchyFromConfig(test.NewLogger(), c)
assert.Equal(t, delay, p.GetDelay()) assert.Equal(t, delay, p.GetDelay())
assert.False(t, p.GetRespond()) assert.False(t, p.GetRespond())
@@ -76,3 +78,158 @@ punchy:
assert.Equal(t, newDelay, p.GetDelay()) assert.Equal(t, newDelay, p.GetDelay())
assert.True(t, p.GetRespond()) assert.True(t, p.GetRespond())
} }
// The tests below pin the shape of each log line Punchy produces so changes
// cannot silently break whatever operators are grepping for. The assertions
// are on the structured message + attrs (e.g. "punchy.respond changed" with
// a respond=true field) rather than a formatted string.
//
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
// not supported" warning whenever any key under punchy changes, because of
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
// punchy form. The tests filter by message rather than asserting total
// entry counts so that warning is tolerated without being locked into
// the format.
type capturedEntry struct {
Level slog.Level
Msg string
Attrs map[string]any
}
// capturingHandler is a slog.Handler that records each Record it receives so
// tests can assert on the level, message, and attribute map of individual log
// lines without coupling to any specific text format.
type capturingHandler struct {
entries []capturedEntry
}
func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error {
e := capturedEntry{
Level: r.Level,
Msg: r.Message,
Attrs: make(map[string]any),
}
r.Attrs(func(a slog.Attr) bool {
e.Attrs[a.Key] = a.Value.Resolve().Any()
return true
})
h.entries = append(h.entries, e)
return nil
}
func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h }
// newCapturingPunchyLogger returns a slog.Logger wired to a fresh
// capturingHandler, along with the handler itself so the caller can inspect
// the entries it records.
func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) {
	t.Helper()
	h := &capturingHandler{}
	logger := slog.New(h)
	return logger, h
}
// findEntry returns the first captured entry whose message equals msg,
// failing the test immediately when no entry matches.
func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry {
	t.Helper()
	for i := range entries {
		if entries[i].Msg == msg {
			return entries[i]
		}
	}
	t.Fatalf("no entry with message %q among %d entries", msg, len(entries))
	return capturedEntry{}
}
// TestPunchy_LogFormat_InitialEnabled pins the "punchy enabled" line emitted
// on initial load: Info level with no attributes.
func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	c := config.NewC(test.NewLogger())
	require.NoError(t, c.LoadString(`punchy: {punch: true}`))

	NewPunchyFromConfig(logger, c)

	got := findEntry(t, captured.entries, "punchy enabled")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Empty(t, got.Attrs)
}
// TestPunchy_LogFormat_InitialDisabled pins the "punchy disabled" line emitted
// on initial load: Info level with no attributes.
func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	c := config.NewC(test.NewLogger())
	require.NoError(t, c.LoadString(`punchy: {punch: false}`))

	NewPunchyFromConfig(logger, c)

	got := findEntry(t, captured.entries, "punchy disabled")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Empty(t, got.Attrs)
}
// TestPunchy_LogFormat_ReloadPunchUnsupported pins the warning emitted when a
// reload tries to flip punchy.punch, which is not supported at runtime.
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	c := config.NewC(test.NewLogger())
	require.NoError(t, c.LoadString(`punchy: {punch: false}`))
	NewPunchyFromConfig(logger, c)

	// Discard the initial-load entries; only the reload output matters here.
	captured.entries = nil
	require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))

	got := findEntry(t, captured.entries, "Changing punchy.punch with reload is not supported, ignoring.")
	assert.Equal(t, slog.LevelWarn, got.Level)
	assert.Empty(t, got.Attrs)
}
// TestPunchy_LogFormat_ReloadRespond pins the "punchy.respond changed" line:
// Info level, with the new value under a "respond" attr.
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	c := config.NewC(test.NewLogger())
	require.NoError(t, c.LoadString(`punchy: {respond: false}`))
	NewPunchyFromConfig(logger, c)

	// Discard the initial-load entries; only the reload output matters here.
	captured.entries = nil
	require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))

	got := findEntry(t, captured.entries, "punchy.respond changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"respond": true}, got.Attrs)
}
// TestPunchy_LogFormat_ReloadDelay pins the "punchy.delay changed" line:
// Info level, with the new value as a duration-typed "delay" attr.
func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	c := config.NewC(test.NewLogger())
	require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
	NewPunchyFromConfig(logger, c)

	// Discard the initial-load entries; only the reload output matters here.
	captured.entries = nil
	require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))

	got := findEntry(t, captured.entries, "punchy.delay changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"delay": 10 * time.Second}, got.Attrs)
}
// TestPunchy_LogFormat_ReloadTargetAllRemotes pins the
// "punchy.target_all_remotes changed" line: Info level, with the new value
// under a "target_all_remotes" attr.
func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	c := config.NewC(test.NewLogger())
	require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
	NewPunchyFromConfig(logger, c)

	// Discard the initial-load entries; only the reload output matters here.
	captured.entries = nil
	require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))

	got := findEntry(t, captured.entries, "punchy.target_all_remotes changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"target_all_remotes": true}, got.Attrs)
}
// TestPunchy_LogFormat_ReloadRespondDelay pins the "punchy.respond_delay
// changed" line: Info level, with the new value as a duration-typed
// "respond_delay" attr.
func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
	logger, captured := newCapturingPunchyLogger(t)
	c := config.NewC(test.NewLogger())
	require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
	NewPunchyFromConfig(logger, c)

	// Discard the initial-load entries; only the reload output matters here.
	captured.entries = nil
	require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))

	got := findEntry(t, captured.entries, "punchy.respond_delay changed")
	assert.Equal(t, slog.LevelInfo, got.Level)
	assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, got.Attrs)
}

View File

@@ -5,22 +5,22 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
type relayManager struct { type relayManager struct {
l *logrus.Logger l *slog.Logger
hostmap *HostMap hostmap *HostMap
amRelay atomic.Bool amRelay atomic.Bool
} }
func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager { func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
rm := &relayManager{ rm := &relayManager{
l: l, l: l,
hostmap: hostmap, hostmap: hostmap,
@@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
err := rm.reload(c, false) err := rm.reload(c, false)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to reload relay_manager") rm.l.Error("Failed to reload relay_manager", "error", err)
} }
}) })
return rm return rm
@@ -52,7 +52,7 @@ func (rm *relayManager) setAmRelay(v bool) {
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
hm.Lock() hm.Lock()
defer hm.Unlock() defer hm.Unlock()
for range 32 { for range 32 {
@@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
if !ok { if !ok {
fields := logrus.Fields{ var relayFrom, relayTo any
"relay": relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex": m.InitiatorRelayIndex,
}
if m.RelayFromAddr == nil { if m.RelayFromAddr == nil {
fields["relayFrom"] = m.OldRelayFromAddr relayFrom = m.OldRelayFromAddr
} else { } else {
fields["relayFrom"] = m.RelayFromAddr relayFrom = m.RelayFromAddr
} }
if m.RelayToAddr == nil { if m.RelayToAddr == nil {
fields["relayTo"] = m.OldRelayToAddr relayTo = m.OldRelayToAddr
} else { } else {
fields["relayTo"] = m.RelayToAddr relayTo = m.RelayToAddr
} }
rm.l.WithFields(fields).Info("relayManager failed to update relay") rm.l.Info("relayManager failed to update relay",
"relay", relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex", m.InitiatorRelayIndex,
"relayFrom", relayFrom,
"relayTo", relayTo,
)
return nil, fmt.Errorf("unknown relay") return nil, fmt.Errorf("unknown relay")
} }
@@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
msg := &NebulaControl{} msg := &NebulaControl{}
err := msg.Unmarshal(d) err := msg.Unmarshal(d)
if err != nil { if err != nil {
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") h.logger(f.l).Error("Failed to unmarshal control message", "error", err)
return return
} }
@@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
} }
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{ rm.l.Info("handleCreateRelayResponse",
"relayFrom": protoAddrToNetAddr(m.RelayFromAddr), "relayFrom", protoAddrToNetAddr(m.RelayFromAddr),
"relayTo": protoAddrToNetAddr(m.RelayToAddr), "relayTo", protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex", m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex, "responderRelayIndex", m.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}). "vpnAddrs", h.vpnAddrs,
Info("handleCreateRelayResponse") )
target := m.RelayToAddr target := m.RelayToAddr
targetAddr := protoAddrToNetAddr(target) targetAddr := protoAddrToNetAddr(target)
relay, err := rm.EstablishRelay(h, m) relay, err := rm.EstablishRelay(h, m)
if err != nil { if err != nil {
rm.l.WithError(err).Error("Failed to update relay for relayTo") rm.l.Error("Failed to update relay for relayTo", "error", err)
return return
} }
// Do I need to complete the relays now? // Do I need to complete the relays now?
@@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
// I'm the middle man. Let the initiator know that the I've established the relay they requested. // I'm the middle man. Let the initiator know that the I've established the relay they requested.
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
if peerHostInfo == nil { if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr)
return return
} }
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok { if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0])
return return
} }
switch peerRelay.State { switch peerRelay.State {
@@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
if v == cert.Version1 { if v == cert.Version1 {
peer := peerHostInfo.vpnAddrs[0] peer := peerHostInfo.vpnAddrs[0]
if !peer.Is4() { if !peer.Is4() {
rm.l.WithField("relayFrom", peer). rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address",
WithField("relayTo", target). "relayFrom", peer,
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). "relayTo", target,
WithField("responderRelayIndex", resp.ResponderRelayIndex). "initiatorRelayIndex", resp.InitiatorRelayIndex,
WithField("vpnAddrs", peerHostInfo.vpnAddrs). "responderRelayIndex", resp.ResponderRelayIndex,
Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") "vpnAddrs", peerHostInfo.vpnAddrs,
)
return return
} }
@@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
rm.l.WithError(err). rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.Info("send CreateRelayResponse",
"relayFrom": resp.RelayFromAddr, "relayFrom", resp.RelayFromAddr,
"relayTo": resp.RelayToAddr, "relayTo", resp.RelayToAddr,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs": peerHostInfo.vpnAddrs}). "vpnAddrs", peerHostInfo.vpnAddrs,
Info("send CreateRelayResponse") )
} }
} }
} }
@@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
from := protoAddrToNetAddr(m.RelayFromAddr) from := protoAddrToNetAddr(m.RelayFromAddr)
target := protoAddrToNetAddr(m.RelayToAddr) target := protoAddrToNetAddr(m.RelayToAddr)
logMsg := rm.l.WithFields(logrus.Fields{ logMsg := rm.l.With(
"relayFrom": from, "relayFrom", from,
"relayTo": target, "relayTo", target,
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex", m.InitiatorRelayIndex,
"vpnAddrs": h.vpnAddrs}) "vpnAddrs", h.vpnAddrs,
)
logMsg.Info("handleCreateRelayRequest") logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to // Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects. // an issue migrating relays over to newly re-handshaked host info objects.
if f.myVpnAddrsTable.Contains(from) { if f.myVpnAddrsTable.Contains(from) {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself") logMsg.Error("Discarding relay request from myself", "myIP", from)
return return
} }
@@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if existingRelay.RemoteIndex != m.InitiatorRelayIndex { if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before. // We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created. // This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{ logMsg.Error("Existing relay mismatch with CreateRelayRequest",
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") "existingRemoteIndex", existingRelay.RemoteIndex)
return return
} }
case Disestablished: case Disestablished:
if existingRelay.RemoteIndex != m.InitiatorRelayIndex { if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before. // We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created. // This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{ logMsg.Error("Existing relay mismatch with CreateRelayRequest",
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") "existingRemoteIndex", existingRelay.RemoteIndex)
return return
} }
// Mark the relay as 'Established' because it's safe to use again // Mark the relay as 'Established' because it's safe to use again
h.relayState.UpdateRelayForByIpState(from, Established) h.relayState.UpdateRelayForByIpState(from, Established)
case PeerRequested: case PeerRequested:
// I should never be in this state, because I am terminal, not forwarding. // I should never be in this state, because I am terminal, not forwarding.
logMsg.WithFields(logrus.Fields{ logMsg.Error("Unexpected Relay State found",
"existingRemoteIndex": existingRelay.RemoteIndex, "existingRemoteIndex", existingRelay.RemoteIndex,
"state": existingRelay.State}).Error("Unexpected Relay State found") "state", existingRelay.State)
} }
} else { } else {
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
if err != nil { if err != nil {
logMsg.WithError(err).Error("Failed to add relay") logMsg.Error("Failed to add relay", "error", err)
return return
} }
} }
relay, ok := h.relayState.QueryRelayForByIp(from) relay, ok := h.relayState.QueryRelayForByIp(from)
if !ok { if !ok {
logMsg.WithField("from", from).Error("Relay State not found") logMsg.Error("Relay State not found", "from", from)
return return
} }
@@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
logMsg. logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.Info("send CreateRelayResponse",
"relayFrom": from, "relayFrom", from,
"relayTo": target, "relayTo", target,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex", resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex", resp.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}). "vpnAddrs", h.vpnAddrs,
Info("send CreateRelayResponse") )
} }
return return
} else { } else {
@@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if v == cert.Version1 { if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() { if !h.vpnAddrs[0].Is4() {
rm.l.WithField("relayFrom", h.vpnAddrs[0]). rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address",
WithField("relayTo", target). "relayFrom", h.vpnAddrs[0],
WithField("initiatorRelayIndex", req.InitiatorRelayIndex). "relayTo", target,
WithField("responderRelayIndex", req.ResponderRelayIndex). "initiatorRelayIndex", req.InitiatorRelayIndex,
WithField("vpnAddr", target). "responderRelayIndex", req.ResponderRelayIndex,
Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") "vpnAddr", target,
)
return return
} }
@@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
logMsg. logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err)
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.Info("send CreateRelayRequest",
"relayFrom": h.vpnAddrs[0], "relayFrom", h.vpnAddrs[0],
"relayTo": target, "relayTo", target,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex", req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex, "responderRelayIndex", req.ResponderRelayIndex,
"vpnAddr": target}). "vpnAddr", target,
Info("send CreateRelayRequest") )
} }
// Also track the half-created Relay state just received // Also track the half-created Relay state just received
@@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if !ok { if !ok {
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
if err != nil { if err != nil {
logMsg. logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err)
WithError(err).Error("relayManager Failed to allocate a local index for relay")
return return
} }
} }

View File

@@ -2,6 +2,7 @@ package nebula
import ( import (
"context" "context"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@@ -10,8 +11,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/sirupsen/logrus"
) )
// forEachFunc is used to benefit folks that want to do work inside the lock // forEachFunc is used to benefit folks that want to do work inside the lock
@@ -66,11 +65,11 @@ type hostnamesResults struct {
network string network string
lookupTimeout time.Duration lookupTimeout time.Duration
cancelFn func() cancelFn func()
l *logrus.Logger l *slog.Logger
ips atomic.Pointer[map[netip.AddrPort]struct{}] ips atomic.Pointer[map[netip.AddrPort]struct{}]
} }
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
r := &hostnamesResults{ r := &hostnamesResults{
hostnames: make([]hostnamePort, len(hostPorts)), hostnames: make([]hostnamePort, len(hostPorts)),
network: network, network: network,
@@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
timeoutCancel() timeoutCancel()
if err != nil { if err != nil {
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") l.Error("DNS resolution failed for static_map host",
"hostname", hostPort.name,
"network", r.network,
"error", err,
)
continue continue
} }
for _, a := range addrs { for _, a := range addrs {
@@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
} }
} }
if different { if different {
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") l.Info("DNS results changed for host list",
"origSet", origSet,
"newSet", netipAddrs,
)
r.ips.Store(&netipAddrs) r.ips.Store(&netipAddrs)
onUpdate() onUpdate()
} }

View File

@@ -10,11 +10,11 @@ import (
"time" "time"
"dario.cat/mergo" "dario.cat/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3" "go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
panic(err) panic(err)
} }
logger := logrus.New() logger := logging.NewLogger(os.Stdout)
logger.Out = os.Stdout
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil { if err != nil {

85
ssh.go
View File

@@ -6,21 +6,21 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"log/slog"
"maps" "maps"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
) )
@@ -57,12 +57,12 @@ type sshDeviceInfoFlags struct {
Pretty bool Pretty bool
} }
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) {
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshRun, err := configSSH(l, ssh, c) sshRun, err := configSSH(l, ssh, c)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd") l.Error("Failed to reconfigure the sshd", "error", err)
ssh.Stop() ssh.Stop()
} }
if sshRun != nil { if sshRun != nil {
@@ -78,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
// updates the passed-in SSHServer. On success, it returns a function // updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On // that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error. // failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
listen := c.GetString("sshd.listen", "") listen := c.GetString("sshd.listen", "")
if listen == "" { if listen == "" {
return nil, fmt.Errorf("sshd.listen must be provided") return nil, fmt.Errorf("sshd.listen must be provided")
@@ -120,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, caAuthorizedKey := range rawCAs { for _, caAuthorizedKey := range rawCAs {
err := ssh.AddTrustedCA(caAuthorizedKey) err := ssh.AddTrustedCA(caAuthorizedKey)
if err != nil { if err != nil {
l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring") l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey)
continue continue
} }
} }
@@ -131,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, rk := range keys { for _, rk := range keys {
kDef, ok := rk.(map[string]any) kDef, ok := rk.(map[string]any)
if !ok { if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk)
continue continue
} }
user, ok := kDef["user"].(string) user, ok := kDef["user"].(string)
if !ok { if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field") l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk)
continue continue
} }
@@ -146,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
case string: case string:
err := ssh.AddAuthorizedKey(user, v) err := ssh.AddAuthorizedKey(user, v)
if err != nil { if err != nil {
l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key") l.Warn("Failed to authorize key",
"error", err,
"sshKeyConfig", rk,
"sshKey", v,
)
continue continue
} }
@@ -154,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
for _, subK := range v { for _, subK := range v {
sk, ok := subK.(string) sk, ok := subK.(string)
if !ok { if !ok {
l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key") l.Warn("Did not understand ssh key",
"sshKeyConfig", rk,
"sshKey", subK,
)
continue continue
} }
err := ssh.AddAuthorizedKey(user, sk) err := ssh.AddAuthorizedKey(user, sk)
if err != nil { if err != nil {
l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key") l.Warn("Failed to authorize key",
"error", err,
"sshKeyConfig", sk,
)
continue continue
} }
} }
default: default:
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood") l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk)
} }
} }
} else { } else {
@@ -178,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
ssh.Stop() ssh.Stop()
runner = func() { runner = func() {
if err := ssh.Run(listen); err != nil { if err := ssh.Run(listen); err != nil {
l.WithField("err", err).Warn("Failed to run the SSH server") l.Warn("Failed to run the SSH server", "error", err)
} }
} }
} else { } else {
@@ -188,7 +198,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
return runner, nil return runner, nil
} }
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
// sandboxDir defaults to a dir in temp. The intention is that end user will // sandboxDir defaults to a dir in temp. The intention is that end user will
// create this dir as needed. Overriding this config value to "" allows // create this dir as needed. Overriding this config value to "" allows
// writing to anywhere in the system. // writing to anywhere in the system.
@@ -789,36 +799,45 @@ func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
} }
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { ctrl, ok := l.Handler().(interface {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) GetLevel() slog.Level
SetLevel(slog.Level)
})
if !ok {
return w.WriteLine("Log level is not reconfigurable on this logger")
} }
level, err := logrus.ParseLevel(a[0]) if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
}
level, err := logging.ParseLevel(strings.ToLower(a[0]))
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels)) return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a))
} }
l.SetLevel(level) ctrl.SetLevel(level)
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
}
func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
ctrl, ok := l.Handler().(interface {
GetFormat() string
SetFormat(string) error
})
if !ok {
return w.WriteLine("Log format is not reconfigurable on this logger")
} }
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
} }
logFormat := strings.ToLower(a[0]) if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil {
switch logFormat { return err
case "text":
l.Formatter = &logrus.TextFormatter{}
case "json":
l.Formatter = &logrus.JSONFormatter{}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
} }
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
} }
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {

View File

@@ -5,16 +5,16 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net" "net"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type SSHServer struct { type SSHServer struct {
config *ssh.ServerConfig config *ssh.ServerConfig
l *logrus.Entry l *slog.Logger
certChecker *ssh.CertChecker certChecker *ssh.CertChecker
@@ -33,7 +33,7 @@ type SSHServer struct {
} }
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &SSHServer{ s := &SSHServer{
@@ -121,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error {
} }
s.trustedCAs = append(s.trustedCAs, pk) s.trustedCAs = append(s.trustedCAs, pk)
s.l.WithField("sshKey", pubKey).Info("Trusted CA key") s.l.Info("Trusted CA key", "sshKey", pubKey)
return nil return nil
} }
@@ -139,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
} }
tk[string(pk.Marshal())] = true tk[string(pk.Marshal())] = true
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") s.l.Info("Authorized ssh key",
"sshKey", pubKey,
"sshUser", user,
)
return nil return nil
} }
@@ -156,7 +159,7 @@ func (s *SSHServer) Run(addr string) error {
return err return err
} }
s.l.WithField("sshListener", addr).Info("SSH server is listening") s.l.Info("SSH server is listening", "sshListener", addr)
// Run loops until there is an error // Run loops until there is an error
s.run() s.run()
@@ -172,7 +175,7 @@ func (s *SSHServer) run() {
c, err := s.listener.Accept() c, err := s.listener.Accept()
if err != nil { if err != nil {
if !errors.Is(err, net.ErrClosed) { if !errors.Is(err, net.ErrClosed) {
s.l.WithError(err).Warn("Error in listener, shutting down") s.l.Warn("Error in listener, shutting down", "error", err)
} }
return return
} }
@@ -193,23 +196,29 @@ func (s *SSHServer) run() {
} }
if err != nil { if err != nil {
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) l := s.l.With(
"error", err,
"remoteAddress", c.RemoteAddr(),
)
if conn != nil { if conn != nil {
l = l.WithField("sshUser", conn.User()) l = l.With("sshUser", conn.User())
conn.Close() conn.Close()
} }
if fp != "" { if fp != "" {
l = l.WithField("sshFingerprint", fp) l = l.With("sshFingerprint", fp)
} }
l.Warn("failed to handshake") l.Warn("failed to handshake")
sessionCancel() sessionCancel()
return return
} }
l := s.l.WithField("sshUser", conn.User()) l := s.l.With("sshUser", conn.User())
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") l.Info("ssh user logged in",
"remoteAddress", c.RemoteAddr(),
"sshFingerprint", fp,
)
NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session")) NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session"))
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
@@ -221,7 +230,7 @@ func (s *SSHServer) Stop() {
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run // Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil { if s.listener != nil {
if err := s.listener.Close(); err != nil { if err := s.listener.Close(); err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener") s.l.Warn("Failed to close the sshd listener", "error", err)
} }
} }
} }

View File

@@ -2,25 +2,25 @@ package sshd
import ( import (
"fmt" "fmt"
"log/slog"
"sort" "sort"
"strings" "strings"
"github.com/anmitsu/go-shlex" "github.com/anmitsu/go-shlex"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/term" "golang.org/x/term"
) )
type session struct { type session struct {
l *logrus.Entry l *slog.Logger
c *ssh.ServerConn c *ssh.ServerConn
term *term.Terminal term *term.Terminal
commands *radix.Tree commands *radix.Tree
cancel func() cancel func()
} }
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session { func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session {
s := &session{ s := &session{
commands: radix.NewFromMap(commands.ToMap()), commands: radix.NewFromMap(commands.ToMap()),
l: l, l: l,
@@ -45,14 +45,14 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
defer s.Close() defer s.Close()
for newChannel := range chans { for newChannel := range chans {
if newChannel.ChannelType() != "session" { if newChannel.ChannelType() != "session" {
s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType())
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue continue
} }
channel, requests, err := newChannel.Accept() channel, requests, err := newChannel.Accept()
if err != nil { if err != nil {
s.l.WithError(err).Warn("could not accept channel") s.l.Warn("could not accept channel", "error", err)
continue continue
} }
@@ -95,12 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
return return
default: default:
s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request") s.l.Debug("Rejected unknown request", "sshRequest", req.Type)
err = req.Reply(false, nil) err = req.Reply(false, nil)
} }
if err != nil { if err != nil {
s.l.WithError(err).Info("Error handling ssh session requests") s.l.Info("Error handling ssh session requests", "error", err)
return return
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"log/slog"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
@@ -15,14 +16,13 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
// startStats initializes stats from config. On success, if any further work // startStats initializes stats from config. On success, if any further work
// is needed to serve stats, it returns a func to handle that work. If no // is needed to serve stats, it returns a func to handle that work. If no
// work is needed, it'll return nil. On failure, it returns nil, error. // work is needed, it'll return nil. On failure, it returns nil, error.
func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { func startStats(l *slog.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
mType := c.GetString("stats.type", "") mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" { if mType == "" || mType == "none" {
return nil, nil return nil, nil
@@ -59,7 +59,7 @@ func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest b
return startFn, nil return startFn, nil
} }
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error { func startGraphiteStats(l *slog.Logger, i time.Duration, c *config.C, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp") proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "") host := c.GetString("stats.host", "")
if host == "" { if host == "" {
@@ -73,13 +73,17 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTe
} }
if !configTest { if !configTest {
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) l.Info("Starting graphite",
"interval", i,
"prefix", prefix,
"addr", addr.String(),
)
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
} }
return nil return nil
} }
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { func startPrometheusStats(l *slog.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
namespace := c.GetString("stats.namespace", "") namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "") subsystem := c.GetString("stats.subsystem", "")
@@ -116,9 +120,15 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV
var startFn func() var startFn func()
if !configTest { if !configTest {
// promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger,
// so bridge our slog.Logger back to a *log.Logger that emits at Error.
errLog := slog.NewLogLogger(l.Handler(), slog.LevelError)
startFn = func() { startFn = func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path) l.Info("Prometheus stats listening",
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) "listen", listen,
"path", path,
)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog}))
log.Fatal(http.ListenAndServe(listen, nil)) log.Fatal(http.ListenAndServe(listen, nil))
} }
} }

View File

@@ -1,29 +1,73 @@
package test package test
import ( import (
"context"
"io" "io"
"log/slog"
"os" "os"
"time"
"github.com/sirupsen/logrus" "github.com/slackhq/nebula/logging"
) )
func NewLogger() *logrus.Logger { // NewLogger returns a *slog.Logger suitable for use in tests. Output goes to
l := logrus.New() // io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to
// stream output to stderr for local debugging.
func NewLogger() *slog.Logger {
v := os.Getenv("TEST_LOGS") v := os.Getenv("TEST_LOGS")
if v == "" { if v == "" {
l.SetOutput(io.Discard) return slog.New(slog.DiscardHandler)
return l
} }
level := slog.LevelInfo
switch v { switch v {
case "2": case "2":
l.SetLevel(logrus.DebugLevel) level = slog.LevelDebug
case "3": case "3":
l.SetLevel(logrus.TraceLevel) level = logging.LevelTrace
default: }
l.SetLevel(logrus.InfoLevel) return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
} }
return l // NewLoggerWithOutput returns a *slog.Logger whose text output is captured by
// w. Timestamps are suppressed so tests can assert on exact output without
// baking the current time into expected strings.
func NewLoggerWithOutput(w io.Writer) *slog.Logger {
return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, nil)})
}
// NewLoggerWithOutputAndLevel is NewLoggerWithOutput with an explicit level
// so tests can exercise Enabled-gated paths.
func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger {
return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})})
}
// NewJSONLoggerWithOutput returns a *slog.Logger emitting JSON to w with
// timestamps suppressed, for tests that pin the JSON shape.
func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger {
return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, &slog.HandlerOptions{Level: level})})
}
// stripTimeHandler zeros each record's time before delegating so slog's
// built-in handlers skip emitting the time attribute. Used to avoid
// timestamp-dependent assertions in tests without resorting to ReplaceAttr.
type stripTimeHandler struct {
inner slog.Handler
}
func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool {
return h.inner.Enabled(ctx, l)
}
func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error {
r.Time = time.Time{}
return h.inner.Handle(ctx, r)
}
func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)}
}
func (h *stripTimeHandler) WithGroup(name string) slog.Handler {
return &stripTimeHandler{inner: h.inner.WithGroup(name)}
} }

View File

@@ -9,11 +9,12 @@ import (
"net/netip" "net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus" "log/slog"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch) return NewGenericListener(l, ip, port, multi, batch)
} }

View File

@@ -12,11 +12,12 @@ import (
"net/netip" "net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus" "log/slog"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch) return NewGenericListener(l, ip, port, multi, batch)
} }

View File

@@ -8,12 +8,12 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -22,12 +22,12 @@ type StdConn struct {
*net.UDPConn *net.UDPConn
isV4 bool isV4 bool
sysFd uintptr sysFd uintptr
l *logrus.Logger l *slog.Logger
} }
var _ Conn = &StdConn{} var _ Conn = &StdConn{}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi) lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil { if err != nil {
@@ -176,7 +176,7 @@ func (u *StdConn) ListenOut(r EncReader) error {
return err return err
} }
u.l.WithError(err).Error("unexpected udp socket receive error") u.l.Error("unexpected udp socket receive error", "error", err)
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
@@ -196,7 +196,7 @@ func (u *StdConn) Rebind() error {
} }
if err != nil { if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket") u.l.Error("Failed to rebind udp socket", "error", err)
} }
return nil return nil

View File

@@ -12,22 +12,22 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
type GenericConn struct { type GenericConn struct {
*net.UDPConn *net.UDPConn
l *logrus.Logger l *slog.Logger
} }
var _ Conn = &GenericConn{} var _ Conn = &GenericConn{}
func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi) lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil { if err != nil {
@@ -88,7 +88,7 @@ func (u *GenericConn) ListenOut(r EncReader) error {
// Dampen unexpected message warns to once per minute // Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
lastRecvErr = time.Now() lastRecvErr = time.Now()
u.l.WithError(err).Warn("unexpected udp socket receive error") u.l.Warn("unexpected udp socket receive error", "error", err)
} }
continue continue
} }

View File

@@ -7,13 +7,13 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -22,7 +22,7 @@ type StdConn struct {
udpConn *net.UDPConn udpConn *net.UDPConn
rawConn syscall.RawConn rawConn syscall.RawConn
isV4 bool isV4 bool
l *logrus.Logger l *slog.Logger
batch int batch int
} }
@@ -38,7 +38,7 @@ func setReusePort(network, address string, c syscall.RawConn) error {
return opErr return opErr
} }
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
listen := netip.AddrPortFrom(ip, uint16(port)) listen := netip.AddrPortFrom(ip, uint16(port))
lc := net.ListenConfig{} lc := net.ListenConfig{}
if multi { if multi {
@@ -242,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
if err == nil { if err == nil {
s, err := u.GetRecvBuffer() s, err := u.GetRecvBuffer()
if err == nil { if err == nil {
u.l.WithField("size", s).Info("listen.read_buffer was set") u.l.Info("listen.read_buffer was set", "size", s)
} else { } else {
u.l.WithError(err).Warn("Failed to get listen.read_buffer") u.l.Warn("Failed to get listen.read_buffer", "error", err)
} }
} else { } else {
u.l.WithError(err).Error("Failed to set listen.read_buffer") u.l.Error("Failed to set listen.read_buffer", "error", err)
} }
} }
@@ -257,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
if err == nil { if err == nil {
s, err := u.GetSendBuffer() s, err := u.GetSendBuffer()
if err == nil { if err == nil {
u.l.WithField("size", s).Info("listen.write_buffer was set") u.l.Info("listen.write_buffer was set", "size", s)
} else { } else {
u.l.WithError(err).Warn("Failed to get listen.write_buffer") u.l.Warn("Failed to get listen.write_buffer", "error", err)
} }
} else { } else {
u.l.WithError(err).Error("Failed to set listen.write_buffer") u.l.Error("Failed to set listen.write_buffer", "error", err)
} }
} }
@@ -273,12 +273,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
if err == nil { if err == nil {
s, err := u.GetSoMark() s, err := u.GetSoMark()
if err == nil { if err == nil {
u.l.WithField("mark", s).Info("listen.so_mark was set") u.l.Info("listen.so_mark was set", "mark", s)
} else { } else {
u.l.WithError(err).Warn("Failed to get listen.so_mark") u.l.Warn("Failed to get listen.so_mark", "error", err)
} }
} else { } else {
u.l.WithError(err).Error("Failed to set listen.so_mark") u.l.Error("Failed to set listen.so_mark", "error", err)
} }
} }
} }

View File

@@ -11,11 +11,12 @@ import (
"net/netip" "net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus" "log/slog"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch) return NewGenericListener(l, ip, port, multi, batch)
} }

View File

@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
@@ -17,7 +18,6 @@ import (
"time" "time"
"unsafe" "unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio" "golang.zx2c4.com/wireguard/conn/winrio"
@@ -53,14 +53,14 @@ type ringBuffer struct {
type RIOConn struct { type RIOConn struct {
isOpen atomic.Bool isOpen atomic.Bool
l *logrus.Logger l *slog.Logger
sock windows.Handle sock windows.Handle
rx, tx ringBuffer rx, tx ringBuffer
rq winrio.Rq rq winrio.Rq
results [packetsPerRing]winrio.Result results [packetsPerRing]winrio.Result
} }
func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) {
if !winrio.Initialize() { if !winrio.Initialize() {
return nil, errors.New("could not initialize winrio") return nil, errors.New("could not initialize winrio")
} }
@@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro
return u, nil return u, nil
} }
func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
var err error var err error
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil { if err != nil {
@@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
if err != nil { if err != nil {
// This is a best-effort to prevent errors from being returned by the udp recv operation. // This is a best-effort to prevent errors from being returned by the udp recv operation.
// Quietly log a failure and continue. // Quietly log a failure and continue.
l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl") l.Debug("failed to set UDP_CONNRESET ioctl", "error", err)
} }
ret = 0 ret = 0
@@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
if err != nil { if err != nil {
// This is a best-effort to prevent errors from being returned by the udp recv operation. // This is a best-effort to prevent errors from being returned by the udp recv operation.
// Quietly log a failure and continue. // Quietly log a failure and continue.
l.WithError(err).Debug("failed to set UDP_NETRESET ioctl") l.Debug("failed to set UDP_NETRESET ioctl", "error", err)
} }
err = u.rx.Open() err = u.rx.Open()
@@ -156,7 +156,7 @@ func (u *RIOConn) ListenOut(r EncReader) error {
// Dampen unexpected message warns to once per minute // Dampen unexpected message warns to once per minute
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
lastRecvErr = time.Now() lastRecvErr = time.Now()
u.l.WithError(err).Warn("unexpected udp socket receive error") u.l.Warn("unexpected udp socket receive error", "error", err)
} }
continue continue
} }

View File

@@ -4,12 +4,13 @@
package udp package udp
import ( import (
"context"
"io" "io"
"log/slog"
"net/netip" "net/netip"
"os" "os"
"sync" "sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@@ -46,10 +47,10 @@ type TesterConn struct {
done chan struct{} done chan struct{}
closeOnce sync.Once closeOnce sync.Once
l *logrus.Logger l *slog.Logger
} }
func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
return &TesterConn{ return &TesterConn{
Addr: netip.AddrPortFrom(ip, uint16(port)), Addr: netip.AddrPortFrom(ip, uint16(port)),
RxPackets: make(chan *Packet, 10), RxPackets: make(chan *Packet, 10),
@@ -67,11 +68,12 @@ func (u *TesterConn) Send(packet *Packet) {
if err := h.Parse(packet.Data); err != nil { if err := h.Parse(packet.Data); err != nil {
panic(err) panic(err)
} }
if u.l.Level >= logrus.DebugLevel { if u.l.Enabled(context.Background(), slog.LevelDebug) {
u.l.WithField("header", h). u.l.Debug("UDP receiving injected packet",
WithField("udpAddr", packet.From). "header", h,
WithField("dataLen", len(packet.Data)). "udpAddr", packet.From,
Debug("UDP receiving injected packet") "dataLen", len(packet.Data),
)
} }
select { select {
case <-u.done: case <-u.done:

View File

@@ -5,14 +5,13 @@ package udp
import ( import (
"fmt" "fmt"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus"
) )
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
if multi { if multi {
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
// The udp stack would need to be reworked to hide away the implementation differences between // The udp stack would need to be reworked to hide away the implementation differences between
@@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return rc, nil return rc, nil
} }
l.WithError(err).Error("Falling back to standard udp sockets") l.Error("Falling back to standard udp sockets", "error", err)
return NewGenericListener(l, ip, port, multi, batch) return NewGenericListener(l, ip, port, multi, batch)
} }

View File

@@ -1,10 +1,10 @@
package util package util
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"github.com/sirupsen/logrus"
) )
type ContextualError struct { type ContextualError struct {
@@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error {
} }
// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError // LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) {
switch v := err.(type) { switch v := err.(type) {
case *ContextualError: case *ContextualError:
v.Log(l) v.Log(l)
default: default:
l.WithError(err).Error(msg) l.Error(msg, "error", err)
} }
} }
@@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error {
return ce.RealError return ce.RealError
} }
func (ce *ContextualError) Log(lr *logrus.Logger) { // Log emits ce as a single error-level log line with Fields and RealError
// promoted to top-level attributes, producing a flat shape callers can grep
// or parse without walking into a nested object.
func (ce *ContextualError) Log(l *slog.Logger) {
attrs := make([]slog.Attr, 0, len(ce.Fields)+1)
for k, v := range ce.Fields {
attrs = append(attrs, slog.Any(k, v))
}
if ce.RealError != nil { if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) attrs = append(attrs, slog.Any("error", ce.RealError))
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
} }
// LogAttrs is intentional: attrs is built from a map[string]any so it has
// no pair-form equivalent.
//nolint:sloglint
l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...)
} }

View File

@@ -1,95 +1,67 @@
package util package util
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type m = map[string]any type m = map[string]any
type TestLogWriter struct {
Logs []string
}
func NewTestLogWriter() *TestLogWriter {
return &TestLogWriter{Logs: make([]string, 0)}
}
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
tl.Logs = append(tl.Logs, string(p))
return len(p), nil
}
func (tl *TestLogWriter) Reset() {
tl.Logs = tl.Logs[:0]
}
func TestContextualError_Log(t *testing.T) { func TestContextualError_Log(t *testing.T) {
l := logrus.New() buf := &bytes.Buffer{}
l.Formatter = &logrus.TextFormatter{ l := test.NewLoggerWithOutput(buf)
DisableTimestamp: true,
DisableColors: true,
}
tl := NewTestLogWriter()
l.Out = tl
// Test a full context line // Test a full context line
tl.Reset() buf.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
e.Log(l) e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())
// Test a line with an error and msg but no fields // Test a line with an error and msg but no fields
tl.Reset() buf.Reset()
e = NewContextualError("test message", nil, errors.New("error")) e = NewContextualError("test message", nil, errors.New("error"))
e.Log(l) e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) assert.Equal(t, "level=ERROR msg=\"test message\" error=error\n", buf.String())
// Test just a context and fields // Test just a context and fields
tl.Reset() buf.Reset()
e = NewContextualError("test message", m{"field": "1"}, nil) e = NewContextualError("test message", m{"field": "1"}, nil)
e.Log(l) e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String())
// Test just a context // Test just a context
tl.Reset() buf.Reset()
e = NewContextualError("test message", nil, nil) e = NewContextualError("test message", nil, nil)
e.Log(l) e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String())
// Test just an error // Test just an error
tl.Reset() buf.Reset()
e = NewContextualError("", nil, errors.New("error")) e = NewContextualError("", nil, errors.New("error"))
e.Log(l) e.Log(l)
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String())
} }
func TestLogWithContextIfNeeded(t *testing.T) { func TestLogWithContextIfNeeded(t *testing.T) {
l := logrus.New() buf := &bytes.Buffer{}
l.Formatter = &logrus.TextFormatter{ l := test.NewLoggerWithOutput(buf)
DisableTimestamp: true,
DisableColors: true,
}
tl := NewTestLogWriter()
l.Out = tl
// Test ignoring fallback context // Test ignoring fallback context
tl.Reset() buf.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
LogWithContextIfNeeded("This should get thrown away", e, l) LogWithContextIfNeeded("This should get thrown away", e, l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())
// Test using fallback context // Test using fallback context
tl.Reset() buf.Reset()
err := fmt.Errorf("this is a normal error") err := fmt.Errorf("this is a normal error")
LogWithContextIfNeeded("Fallback context woo", err, l) LogWithContextIfNeeded("Fallback context woo", err, l)
assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String())
} }
func TestContextualizeIfNeeded(t *testing.T) { func TestContextualizeIfNeeded(t *testing.T) {