mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Switch to slog, remove logrus (#1672)
This commit is contained in:
@@ -2,7 +2,21 @@ version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- sloglint
|
||||
- testifylint
|
||||
settings:
|
||||
sloglint:
|
||||
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
|
||||
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
|
||||
# custom levels instead of LogAttrs when you can.
|
||||
#
|
||||
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
|
||||
# the few legitimate sites (where attrs is built up as a []slog.Attr)
|
||||
# carry a //nolint:sloglint with rationale.
|
||||
kv-only: true
|
||||
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
|
||||
# discard-handler is on by default (since Go 1.24): suggests
|
||||
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
|
||||
38
bits.go
38
bits.go
@@ -1,8 +1,10 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Bits struct {
|
||||
@@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
||||
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
|
||||
// If i is the next number, return true.
|
||||
if i > b.current {
|
||||
return true
|
||||
@@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
||||
}
|
||||
|
||||
// Not within the window
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
l.Debug("rejected a packet (top)",
|
||||
"current", b.current,
|
||||
"incoming", i,
|
||||
)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
|
||||
// If i is the next number, return true and update current.
|
||||
if i == b.current+1 {
|
||||
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
||||
@@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
// Check to see if it's a duplicate
|
||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||
if b.current == i || b.bits[i%b.length] == true {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
||||
Debug("Receive window")
|
||||
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
l.Debug("Receive window",
|
||||
"accepted", false,
|
||||
"currentCounter", b.current,
|
||||
"incomingCounter", i,
|
||||
"reason", "duplicate",
|
||||
)
|
||||
}
|
||||
b.dupeCounter.Inc(1)
|
||||
return false
|
||||
@@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
|
||||
// In all other cases, fail and don't change current.
|
||||
b.outOfWindowCounter.Inc(1)
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("accepted", false).
|
||||
WithField("currentCounter", b.current).
|
||||
WithField("incomingCounter", i).
|
||||
WithField("reason", "nonsense").
|
||||
Debug("Receive window")
|
||||
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
l.Debug("Receive window",
|
||||
"accepted", false,
|
||||
"currentCounter", b.current,
|
||||
"incomingCounter", i,
|
||||
"reason", "nonsense",
|
||||
)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -3,8 +3,15 @@
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
func HookLogger(l *logrus.Logger) {
|
||||
// Do nothing, let the logs flow to stdout/stderr
|
||||
"github.com/slackhq/nebula/logging"
|
||||
)
|
||||
|
||||
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
|
||||
// platforms have no special sink to integrate with.
|
||||
func newPlatformLogger() *slog.Logger {
|
||||
return logging.NewLogger(os.Stdout)
|
||||
}
|
||||
|
||||
@@ -1,54 +1,86 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
)
|
||||
|
||||
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
|
||||
// logrus output will be discarded
|
||||
func HookLogger(l *logrus.Logger) {
|
||||
l.AddHook(newLogHook(logger))
|
||||
l.SetOutput(ioutil.Discard)
|
||||
// newPlatformLogger returns a *slog.Logger that routes every log record
|
||||
// through the Windows service logger so records end up in the Windows
|
||||
// Event Log. All the heavy lifting (level management, format swap,
|
||||
// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
|
||||
// this file only contributes:
|
||||
//
|
||||
// - an io.Writer that forwards each formatted line to the service
|
||||
// logger at the current record's Event Log severity, and
|
||||
// - a thin severityTag that embeds *logging.Handler and overrides
|
||||
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
|
||||
// column and severity-based filters keep working the way they did
|
||||
// before the slog migration.
|
||||
//
|
||||
// Format (text vs json) is carried by the embedded *logging.Handler, so
|
||||
// logging.format: json in config still produces JSON lines in Event
|
||||
// Viewer, same as the pre-slog logrus setup.
|
||||
func newPlatformLogger() *slog.Logger {
|
||||
w := &eventLogWriter{}
|
||||
return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w})
|
||||
}
|
||||
|
||||
type logHook struct {
|
||||
sl service.Logger
|
||||
// eventLogWriter forwards slog-formatted lines to the Windows service
|
||||
// logger at the severity most recently stashed by severityTag.Handle.
|
||||
// The mutex serializes the stash + inner.Handle + Write cycle per record
|
||||
// across all concurrent goroutines; slog's builtin text/json handlers
|
||||
// each hold their own mutex around Write, but that only protects the
|
||||
// Write call itself, not our stash-then-handle sequence.
|
||||
type eventLogWriter struct {
|
||||
mu sync.Mutex
|
||||
level slog.Level
|
||||
}
|
||||
|
||||
func newLogHook(sl service.Logger) *logHook {
|
||||
return &logHook{sl: sl}
|
||||
}
|
||||
|
||||
func (h *logHook) Fire(entry *logrus.Entry) error {
|
||||
line, err := entry.String()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch entry.Level {
|
||||
case logrus.PanicLevel:
|
||||
return h.sl.Error(line)
|
||||
case logrus.FatalLevel:
|
||||
return h.sl.Error(line)
|
||||
case logrus.ErrorLevel:
|
||||
return h.sl.Error(line)
|
||||
case logrus.WarnLevel:
|
||||
return h.sl.Warning(line)
|
||||
case logrus.InfoLevel:
|
||||
return h.sl.Info(line)
|
||||
case logrus.DebugLevel:
|
||||
return h.sl.Info(line)
|
||||
func (w *eventLogWriter) Write(p []byte) (int, error) {
|
||||
line := strings.TrimRight(string(p), "\n")
|
||||
switch {
|
||||
case w.level >= slog.LevelError:
|
||||
return len(p), logger.Error(line)
|
||||
case w.level >= slog.LevelWarn:
|
||||
return len(p), logger.Warning(line)
|
||||
default:
|
||||
return nil
|
||||
return len(p), logger.Info(line)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *logHook) Levels() []logrus.Level {
|
||||
return logrus.AllLevels
|
||||
// severityTag embeds *logging.Handler to pick up everything it does for
|
||||
// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
|
||||
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
|
||||
// so each record's slog.Level is stashed on the writer before formatting
|
||||
// and so derived handlers stay wrapped as severityTag rather than
|
||||
// downgrading to bare *logging.Handler.
|
||||
type severityTag struct {
|
||||
*logging.Handler
|
||||
w *eventLogWriter
|
||||
}
|
||||
|
||||
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
|
||||
s.w.mu.Lock()
|
||||
defer s.w.mu.Unlock()
|
||||
s.w.level = r.Level
|
||||
return s.Handler.Handle(ctx, r)
|
||||
}
|
||||
|
||||
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
if len(attrs) == 0 {
|
||||
return s
|
||||
}
|
||||
return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w}
|
||||
}
|
||||
|
||||
func (s *severityTag) WithGroup(name string) slog.Handler {
|
||||
if name == "" {
|
||||
return s
|
||||
}
|
||||
return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -50,12 +50,11 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
l := logging.NewLogger(os.Stdout)
|
||||
|
||||
if *serviceFlag != "" {
|
||||
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
|
||||
l.WithError(err).Error("Service command failed")
|
||||
l.Error("Service command failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
@@ -74,6 +73,16 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := logging.ApplyConfig(l, c); err != nil {
|
||||
fmt.Printf("failed to apply logging config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
if err := logging.ApplyConfig(l, c); err != nil {
|
||||
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||
@@ -90,7 +99,7 @@ func main() {
|
||||
go ctrl.ShutdownBlock()
|
||||
|
||||
if err := wait(); err != nil {
|
||||
l.WithError(err).Error("Nebula stopped due to fatal error")
|
||||
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
)
|
||||
|
||||
var logger service.Logger
|
||||
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
|
||||
// Start should not block.
|
||||
logger.Info("Nebula service starting.")
|
||||
|
||||
l := logrus.New()
|
||||
HookLogger(l)
|
||||
l := newPlatformLogger()
|
||||
|
||||
c := config.NewC(l)
|
||||
err := c.Load(*p.configPath)
|
||||
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
|
||||
return fmt.Errorf("failed to load config: %s", err)
|
||||
}
|
||||
|
||||
if err := logging.ApplyConfig(l, c); err != nil {
|
||||
return fmt.Errorf("failed to apply logging config: %s", err)
|
||||
}
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
if err := logging.ApplyConfig(l, c); err != nil {
|
||||
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -85,7 +93,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
|
||||
// Here are what the different loggers are doing:
|
||||
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
|
||||
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
|
||||
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
|
||||
// - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
|
||||
s, err := service.New(prg, svcConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -55,8 +55,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
l := logging.NewLogger(os.Stdout)
|
||||
|
||||
c := config.NewC(l)
|
||||
err := c.Load(*configPath)
|
||||
@@ -65,6 +64,16 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := logging.ApplyConfig(l, c); err != nil {
|
||||
fmt.Printf("failed to apply logging config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
if err := logging.ApplyConfig(l, c); err != nil {
|
||||
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||
@@ -82,7 +91,7 @@ func main() {
|
||||
notifyReady(l)
|
||||
|
||||
if err := wait(); err != nil {
|
||||
l.WithError(err).Error("Nebula stopped due to fatal error")
|
||||
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
|
||||
@@ -13,30 +12,30 @@ import (
|
||||
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
|
||||
const SdNotifyReady = "READY=1"
|
||||
|
||||
func notifyReady(l *logrus.Logger) {
|
||||
func notifyReady(l *slog.Logger) {
|
||||
sockName := os.Getenv("NOTIFY_SOCKET")
|
||||
if sockName == "" {
|
||||
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
||||
l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("failed to connect to systemd notification socket")
|
||||
l.Error("failed to connect to systemd notification socket", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||
if err != nil {
|
||||
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
|
||||
l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
|
||||
l.WithError(err).Error("failed to signal the systemd notification socket")
|
||||
l.Error("failed to signal the systemd notification socket", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
l.Debugln("notified systemd the service is ready")
|
||||
l.Debug("notified systemd the service is ready")
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
import "log/slog"
|
||||
|
||||
func notifyReady(_ *logrus.Logger) {
|
||||
func notifyReady(_ *slog.Logger) {
|
||||
// No init service to notify
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"time"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.yaml.in/yaml/v3"
|
||||
)
|
||||
|
||||
@@ -26,11 +26,11 @@ type C struct {
|
||||
Settings map[string]any
|
||||
oldSettings map[string]any
|
||||
callbacks []func(*C)
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
reloadLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewC(l *logrus.Logger) *C {
|
||||
func NewC(l *slog.Logger) *C {
|
||||
return &C{
|
||||
Settings: make(map[string]any),
|
||||
l: l,
|
||||
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
|
||||
|
||||
newVals, err := yaml.Marshal(nv)
|
||||
if err != nil {
|
||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
||||
c.l.Error("Error while marshaling new config",
|
||||
"config_path", k,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
oldVals, err := yaml.Marshal(ov)
|
||||
if err != nil {
|
||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
||||
c.l.Error("Error while marshaling old config",
|
||||
"config_path", k,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
return string(newVals) != string(oldVals)
|
||||
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
|
||||
|
||||
err := c.Load(c.path)
|
||||
if err != nil {
|
||||
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
||||
c.l.Error("Error occurred while reloading config",
|
||||
"config_path", c.path,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@@ -47,10 +47,10 @@ type connectionManager struct {
|
||||
|
||||
metricsTxPunchy metrics.Counter
|
||||
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||
cm := &connectionManager{
|
||||
hostMap: hm,
|
||||
l: l,
|
||||
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
||||
old := cm.getInactivityTimeout()
|
||||
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
||||
if !initial {
|
||||
cm.l.WithField("oldDuration", old).
|
||||
WithField("newDuration", cm.getInactivityTimeout()).
|
||||
Info("Inactivity timeout has changed")
|
||||
cm.l.Info("Inactivity timeout has changed",
|
||||
"oldDuration", old,
|
||||
"newDuration", cm.getInactivityTimeout(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
||||
old := cm.dropInactive.Load()
|
||||
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
||||
if !initial {
|
||||
cm.l.WithField("oldBool", old).
|
||||
WithField("newBool", cm.dropInactive.Load()).
|
||||
Info("Drop inactive setting has changed")
|
||||
cm.l.Info("Drop inactive setting has changed",
|
||||
"oldBool", old,
|
||||
"newBool", cm.dropInactive.Load(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
||||
var err error
|
||||
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||
if err != nil {
|
||||
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
||||
cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
|
||||
continue
|
||||
}
|
||||
switch r.Type {
|
||||
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
||||
|
||||
msg, err := req.Marshal()
|
||||
if err != nil {
|
||||
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
||||
cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
|
||||
} else {
|
||||
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
cm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": req.RelayFromAddr,
|
||||
"relayTo": req.RelayToAddr,
|
||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||
"responderRelayIndex": req.ResponderRelayIndex,
|
||||
"vpnAddrs": newhostinfo.vpnAddrs}).
|
||||
Info("send CreateRelayRequest")
|
||||
cm.l.Info("send CreateRelayRequest",
|
||||
"relayFrom", req.RelayFromAddr,
|
||||
"relayTo", req.RelayToAddr,
|
||||
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||
"responderRelayIndex", req.ResponderRelayIndex,
|
||||
"vpnAddrs", newhostinfo.vpnAddrs,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
hostinfo := cm.hostMap.Indexes[localIndex]
|
||||
if hostinfo == nil {
|
||||
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
|
||||
cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
|
||||
return doNothing, nil, nil
|
||||
}
|
||||
|
||||
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
// A hostinfo is determined alive if there is incoming traffic
|
||||
if inTraffic {
|
||||
decision := doNothing
|
||||
if cm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(cm.l).
|
||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
||||
Debug("Tunnel status")
|
||||
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||
"tunnelCheck", m{"state": "alive", "method": "passive"},
|
||||
)
|
||||
}
|
||||
hostinfo.pendingDeletion.Store(false)
|
||||
|
||||
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
if hostinfo.pendingDeletion.Load() {
|
||||
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
||||
hostinfo.logger(cm.l).
|
||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
||||
Info("Tunnel status")
|
||||
hostinfo.logger(cm.l).Info("Tunnel status",
|
||||
"tunnelCheck", m{"state": "dead", "method": "active"},
|
||||
)
|
||||
|
||||
return deleteTunnel, hostinfo, nil
|
||||
}
|
||||
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
||||
if isInactive {
|
||||
// Tunnel is inactive, tear it down
|
||||
hostinfo.logger(cm.l).
|
||||
WithField("inactiveDuration", inactiveFor).
|
||||
WithField("primary", mainHostInfo).
|
||||
Info("Dropping tunnel due to inactivity")
|
||||
hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
|
||||
"inactiveDuration", inactiveFor,
|
||||
"primary", mainHostInfo,
|
||||
)
|
||||
|
||||
return closeTunnel, hostinfo, primary
|
||||
}
|
||||
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
cm.sendPunch(hostinfo)
|
||||
}
|
||||
|
||||
if cm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(cm.l).
|
||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||
Debug("Tunnel status")
|
||||
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||
"tunnelCheck", m{"state": "testing", "method": "active"},
|
||||
)
|
||||
}
|
||||
|
||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||
decision = sendTestPacket
|
||||
|
||||
} else {
|
||||
if cm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
|
||||
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(cm.l).Debug("Hostinfo sadness")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
|
||||
return false //cert is still valid! yay!
|
||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||
// Block listed certificates should always be disconnected
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||
hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
|
||||
"error", err,
|
||||
"fingerprint", remoteCert.Fingerprint,
|
||||
)
|
||||
return true
|
||||
} else if cm.intf.disconnectInvalid.Load() {
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
|
||||
"error", err,
|
||||
"fingerprint", remoteCert.Fingerprint,
|
||||
)
|
||||
return true
|
||||
} else {
|
||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
curCrtVersion := curCrt.Version()
|
||||
myCrt := cs.getCertificate(curCrtVersion)
|
||||
if myCrt == nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("reason", "local certificate removed").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.l.Info("Re-handshaking with remote",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"version", curCrtVersion,
|
||||
"reason", "local certificate removed",
|
||||
)
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.l.Info("Re-handshaking with remote",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"version", curCrtVersion,
|
||||
"peerVersion", peerCrt.Certificate.Version(),
|
||||
"reason", "local certificate version lower than peer, attempting to correct",
|
||||
)
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||
})
|
||||
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.l.Info("Re-handshaking with remote",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"reason", "local certificate is not current",
|
||||
)
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
if curCrtVersion < cs.initiatingVersion {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.l.Info("Re-handshaking with remote",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"reason", "current cert version < pki.initiatingVersion",
|
||||
)
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/overlaytest"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
lh := newTestLighthouse()
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
inside: &overlaytest.NoopTun{},
|
||||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
@@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
conf := config.NewC(test.NewLogger())
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
lh := newTestLighthouse()
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
inside: &overlaytest.NoopTun{},
|
||||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
@@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
conf := config.NewC(test.NewLogger())
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
||||
lh := newTestLighthouse()
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
inside: &overlaytest.NoopTun{},
|
||||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
@@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
conf := config.NewC(test.NewLogger())
|
||||
conf.Settings["tunnels"] = map[string]any{
|
||||
"drop_inactive": true,
|
||||
}
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
assert.True(t, nc.dropInactive.Load())
|
||||
nc.intf = ifce
|
||||
|
||||
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
lh := newTestLighthouse()
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
inside: &overlaytest.NoopTun{},
|
||||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
@@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
ifce.disconnectInvalid.Store(true)
|
||||
|
||||
// Create manager
|
||||
conf := config.NewC(l)
|
||||
punchy := NewPunchyFromConfig(l, conf)
|
||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||
conf := config.NewC(test.NewLogger())
|
||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||
nc.intf = ifce
|
||||
ifce.connectionManager = nc
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
)
|
||||
@@ -27,7 +26,7 @@ type ConnectionState struct {
|
||||
writeLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||
func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||
var dhFunc noise.DHFunc
|
||||
switch crt.Curve() {
|
||||
case cert.Curve_CURVE25519:
|
||||
|
||||
14
control.go
14
control.go
@@ -3,13 +3,13 @@ package nebula
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
@@ -46,7 +46,7 @@ type Control struct {
|
||||
state RunState
|
||||
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
sshStart func()
|
||||
@@ -151,7 +151,7 @@ func (c *Control) Stop() {
|
||||
|
||||
c.CloseAllTunnels(false)
|
||||
if err := c.f.Close(); err != nil {
|
||||
c.l.WithError(err).Error("Close interface failed")
|
||||
c.l.Error("Close interface failed", "error", err)
|
||||
}
|
||||
c.stateLock.Lock()
|
||||
c.state = StateStopped
|
||||
@@ -166,7 +166,7 @@ func (c *Control) ShutdownBlock() {
|
||||
|
||||
rawSig := <-sigChan
|
||||
sig := rawSig.String()
|
||||
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
|
||||
c.l.Info("Caught signal, shutting down", "signal", sig)
|
||||
c.Stop()
|
||||
}
|
||||
|
||||
@@ -303,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
||||
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||
c.f.closeTunnel(h)
|
||||
|
||||
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
|
||||
Debug("Sending close tunnel message")
|
||||
c.l.Debug("Sending close tunnel message",
|
||||
"vpnAddrs", h.vpnAddrs,
|
||||
"udpAddr", h.remote,
|
||||
)
|
||||
closed++
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -83,7 +82,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
f: &Interface{
|
||||
hostMap: hm,
|
||||
},
|
||||
l: logrus.New(),
|
||||
l: test.NewLogger(),
|
||||
}
|
||||
|
||||
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
||||
|
||||
@@ -3,6 +3,7 @@ package nebula
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
@@ -12,13 +13,12 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
type dnsServer struct {
|
||||
sync.RWMutex
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
ctx context.Context
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
@@ -55,7 +55,7 @@ type dnsServer struct {
|
||||
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
|
||||
// watcher that tears the listener down on nebula shutdown. The returned
|
||||
// pointer is always non-nil, even on error.
|
||||
func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||
ds := &dnsServer{
|
||||
l: l,
|
||||
ctx: ctx,
|
||||
@@ -69,7 +69,7 @@ func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
if err := ds.reload(c, false); err != nil {
|
||||
l.WithError(err).Error("Failed to reload DNS responder from config")
|
||||
ds.l.Error("Failed to reload DNS responder from config", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -145,7 +145,7 @@ func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reaso
|
||||
<-started
|
||||
}
|
||||
if err := srv.Shutdown(); err != nil {
|
||||
d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder")
|
||||
d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (d *dnsServer) Start() {
|
||||
}
|
||||
}()
|
||||
|
||||
d.l.WithField("dnsListener", addr).Info("Starting DNS responder")
|
||||
d.l.Info("Starting DNS responder", "dnsListener", addr)
|
||||
err := server.ListenAndServe()
|
||||
close(done)
|
||||
|
||||
@@ -201,7 +201,7 @@ func (d *dnsServer) Start() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
d.l.WithError(err).Warn("Failed to run the DNS responder")
|
||||
d.l.Warn("Failed to run the DNS responder", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,6 +314,7 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
|
||||
}
|
||||
|
||||
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
|
||||
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
|
||||
// type must be answered with NOERROR and an empty answer section (NODATA),
|
||||
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
|
||||
@@ -323,7 +324,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
switch q.Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||
if debugEnabled {
|
||||
d.l.Debug("DNS query", "type", qType, "name", q.Name)
|
||||
}
|
||||
ip, nameExists := d.Query(q.Qtype, q.Name)
|
||||
if nameExists {
|
||||
anyNameExists = true
|
||||
@@ -339,7 +342,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||
return
|
||||
}
|
||||
d.l.Debugf("Query for TXT %s", q.Name)
|
||||
if debugEnabled {
|
||||
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
|
||||
}
|
||||
ip := d.QueryCert(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||
|
||||
@@ -2,7 +2,7 @@ package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -30,7 +29,7 @@ func (stubDNSWriter) TsigTimersOnly(bool) {}
|
||||
func (stubDNSWriter) Hijack() {}
|
||||
|
||||
func TestParsequery(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l := slog.New(slog.DiscardHandler)
|
||||
hostMap := &HostMap{}
|
||||
ds := &dnsServer{
|
||||
l: l,
|
||||
@@ -137,10 +136,9 @@ func Test_getDnsServerAddr(t *testing.T) {
|
||||
|
||||
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
||||
t.Helper()
|
||||
l := logrus.New()
|
||||
l.Out = io.Discard
|
||||
sl := slog.New(slog.DiscardHandler)
|
||||
ds := &dnsServer{
|
||||
l: l,
|
||||
l: sl,
|
||||
ctx: context.Background(),
|
||||
dnsMap4: make(map[string]netip.Addr),
|
||||
dnsMap6: make(map[string]netip.Addr),
|
||||
@@ -148,7 +146,7 @@ func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
||||
}
|
||||
ds.mux = dns.NewServeMux()
|
||||
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||
return ds, config.NewC(l)
|
||||
return ds, config.NewC(nil)
|
||||
}
|
||||
|
||||
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
@@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) {
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||
l := NewTestLogger()
|
||||
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||
@@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) {
|
||||
theirControl.Start()
|
||||
|
||||
r.Log("Get a tunnel between me and relay")
|
||||
l.Info("Get a tunnel between me and relay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||
|
||||
r.Log("Get a tunnel between them and relay")
|
||||
l.Info("Get a tunnel between them and relay")
|
||||
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||
|
||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||
l.Info("Trigger a handshake from both them and me via relay to them and me")
|
||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
||||
|
||||
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
||||
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
||||
|
||||
r.Log("Wait for a packet from them to me")
|
||||
l.Info("Wait for a packet from them to me; myControl")
|
||||
r.Log("Wait for a packet from them to me; myControl")
|
||||
r.RouteForAllUntilTxTun(myControl)
|
||||
l.Info("Wait for a packet from them to me; theirControl")
|
||||
r.Log("Wait for a packet from them to me; theirControl")
|
||||
r.RouteForAllUntilTxTun(theirControl)
|
||||
|
||||
r.Log("Assert the tunnel works")
|
||||
l.Info("Assert the tunnel works")
|
||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
|
||||
t.Log("Wait until we remove extra tunnels")
|
||||
l.Info("Wait until we remove extra tunnels")
|
||||
l.WithFields(
|
||||
logrus.Fields{
|
||||
"myControl": len(myControl.GetHostmap().Indexes),
|
||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
||||
}).Info("Waiting for hostinfos to be removed...")
|
||||
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||
len(myControl.GetHostmap().Indexes),
|
||||
len(theirControl.GetHostmap().Indexes),
|
||||
len(relayControl.GetHostmap().Indexes),
|
||||
)
|
||||
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||
retries := 60
|
||||
for hostInfos > 6 && retries > 0 {
|
||||
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||
l.WithFields(
|
||||
logrus.Fields{
|
||||
"myControl": len(myControl.GetHostmap().Indexes),
|
||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
||||
}).Info("Waiting for hostinfos to be removed...")
|
||||
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||
len(myControl.GetHostmap().Indexes),
|
||||
len(theirControl.GetHostmap().Indexes),
|
||||
len(relayControl.GetHostmap().Indexes),
|
||||
)
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
t.Log("Connection manager hasn't ticked yet")
|
||||
time.Sleep(time.Second)
|
||||
@@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) {
|
||||
}
|
||||
|
||||
r.Log("Assert the tunnel works")
|
||||
l.Info("Assert the tunnel works")
|
||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
|
||||
myControl.Stop()
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -12,15 +11,18 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.yaml.in/yaml/v3"
|
||||
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
|
||||
"level": l.Level.String(),
|
||||
"level": testLogLevelName(),
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||
"level": l.Level.String(),
|
||||
"level": testLogLevelName(),
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
@@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
|
||||
return a
|
||||
}
|
||||
|
||||
func NewTestLogger() *logrus.Logger {
|
||||
l := logrus.New()
|
||||
|
||||
func NewTestLogger() *slog.Logger {
|
||||
v := os.Getenv("TEST_LOGS")
|
||||
if v == "" {
|
||||
l.SetOutput(io.Discard)
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
return l
|
||||
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
}
|
||||
|
||||
level := slog.LevelInfo
|
||||
switch v {
|
||||
case "2":
|
||||
l.SetLevel(logrus.DebugLevel)
|
||||
level = slog.LevelDebug
|
||||
case "3":
|
||||
l.SetLevel(logrus.TraceLevel)
|
||||
default:
|
||||
l.SetLevel(logrus.InfoLevel)
|
||||
level = logging.LevelTrace
|
||||
}
|
||||
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
|
||||
}
|
||||
|
||||
return l
|
||||
// testLogLevelName returns the level name string accepted by logging.ApplyConfig
|
||||
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
|
||||
func testLogLevelName() string {
|
||||
switch os.Getenv("TEST_LOGS") {
|
||||
case "2":
|
||||
return "debug"
|
||||
case "3":
|
||||
return "trace"
|
||||
case "":
|
||||
return "info"
|
||||
}
|
||||
return "info"
|
||||
}
|
||||
|
||||
@@ -292,23 +292,17 @@ tun:
|
||||
|
||||
# Configure logging level
|
||||
logging:
|
||||
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
|
||||
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
|
||||
# scenarios. Debug logging is also CPU intensive and will decrease performance overall.
|
||||
# Only enable debug logging while actively investigating an issue.
|
||||
# trace, debug, info, warn, or error. Default is info and is reloadable.
|
||||
# fatal and panic are accepted for backwards compatibility and map to error.
|
||||
#NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
|
||||
# scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
|
||||
# Only enable debug or trace logging while actively investigating an issue.
|
||||
level: info
|
||||
# json or text formats currently available. Default is text
|
||||
# json or text formats currently available. Default is text.
|
||||
format: text
|
||||
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
|
||||
# Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
|
||||
#disable_timestamp: true
|
||||
# timestamp format is specified in Go time format, see:
|
||||
# https://golang.org/pkg/time/#pkg-constants
|
||||
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
||||
# default when `format: text`:
|
||||
# when TTY attached: seconds since beginning of execution
|
||||
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
||||
# As an example, to log as RFC3339 with millisecond precision, set to:
|
||||
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
|
||||
# Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
|
||||
|
||||
#stats:
|
||||
#type: graphite
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/service"
|
||||
)
|
||||
@@ -64,8 +64,7 @@ pki:
|
||||
return err
|
||||
}
|
||||
|
||||
logger := logrus.New()
|
||||
logger.Out = os.Stdout
|
||||
logger := logging.NewLogger(os.Stdout)
|
||||
|
||||
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||
if err != nil {
|
||||
|
||||
65
firewall.go
65
firewall.go
@@ -1,11 +1,13 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"slices"
|
||||
@@ -16,7 +18,6 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
@@ -67,7 +68,7 @@ type Firewall struct {
|
||||
incomingMetrics firewallMetrics
|
||||
outgoingMetrics firewallMetrics
|
||||
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
type firewallMetrics struct {
|
||||
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
// The certificate provided should be the highest version loaded in memory.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||
func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
var tmin, tmax time.Duration
|
||||
|
||||
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
}
|
||||
}
|
||||
|
||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||
func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||
certificate := cs.getCertificate(cert.Version2)
|
||||
if certificate == nil {
|
||||
certificate = cs.getCertificate(cert.Version1)
|
||||
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
||||
case "drop":
|
||||
fw.InSendReject = false
|
||||
default:
|
||||
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`")
|
||||
l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
|
||||
fw.InSendReject = false
|
||||
}
|
||||
|
||||
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
||||
case "drop":
|
||||
fw.OutSendReject = false
|
||||
default:
|
||||
l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`")
|
||||
l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
|
||||
fw.OutSendReject = false
|
||||
}
|
||||
|
||||
@@ -268,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
|
||||
if startPort != firewall.PortAny {
|
||||
f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule")
|
||||
f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
|
||||
}
|
||||
startPort = firewall.PortAny
|
||||
endPort = firewall.PortAny
|
||||
@@ -290,8 +291,9 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
if !incoming {
|
||||
direction = "outgoing"
|
||||
}
|
||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
||||
Info("Firewall rule added")
|
||||
f.l.Info("Firewall rule added",
|
||||
"firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
|
||||
)
|
||||
|
||||
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
||||
}
|
||||
@@ -314,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
|
||||
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
||||
}
|
||||
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||
func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||
var table string
|
||||
if inbound {
|
||||
table = "firewall.inbound"
|
||||
@@ -372,7 +374,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
startPort = firewall.PortAny
|
||||
endPort = firewall.PortAny
|
||||
if sPort != "" {
|
||||
l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule")
|
||||
l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||
@@ -396,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
}
|
||||
|
||||
if warning := r.sanity(); warning != nil {
|
||||
l.Warnf("%s rule #%v; %s", table, i, warning)
|
||||
l.Warn("firewall rule sanity check",
|
||||
"table", table,
|
||||
"rule", i,
|
||||
"warning", warning,
|
||||
)
|
||||
}
|
||||
|
||||
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
||||
@@ -528,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
|
||||
// We now know which firewall table to check against
|
||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("dropping old conntrack entry, does not match new ruleset")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
|
||||
"fwPacket", fp,
|
||||
"incoming", c.incoming,
|
||||
"rulesVersion", f.rulesVersion,
|
||||
"oldRulesVersion", c.rulesVersion,
|
||||
)
|
||||
}
|
||||
delete(conntrack.Conns, fp)
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("keeping old conntrack entry, does match new ruleset")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
|
||||
"fwPacket", fp,
|
||||
"incoming", c.incoming,
|
||||
"rulesVersion", f.rulesVersion,
|
||||
"oldRulesVersion", c.rulesVersion,
|
||||
)
|
||||
}
|
||||
|
||||
c.rulesVersion = f.rulesVersion
|
||||
@@ -935,7 +941,7 @@ type rule struct {
|
||||
CASha string
|
||||
}
|
||||
|
||||
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
||||
func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
|
||||
r := rule{}
|
||||
|
||||
m, ok := p.(map[string]any)
|
||||
@@ -966,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
||||
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
||||
}
|
||||
|
||||
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
||||
l.Warn("group was an array with a single value, converting to simple value",
|
||||
"table", table,
|
||||
"rule", i,
|
||||
)
|
||||
m["group"] = v[0]
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||
@@ -16,15 +15,17 @@ type ConntrackCacheTicker struct {
|
||||
cacheV uint64
|
||||
cacheTick atomic.Uint64
|
||||
|
||||
l *slog.Logger
|
||||
cache ConntrackCache
|
||||
}
|
||||
|
||||
func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker {
|
||||
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
|
||||
if d == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &ConntrackCacheTicker{
|
||||
l: l,
|
||||
cache: ConntrackCache{},
|
||||
}
|
||||
|
||||
@@ -48,15 +49,15 @@ func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
|
||||
|
||||
// Get checks if the cache ticker has moved to the next version before returning
|
||||
// the map. If it has moved, we reset the map.
|
||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||
c.cacheV = tick
|
||||
if ll := len(c.cache); ll > 0 {
|
||||
if l.Level == logrus.DebugLevel {
|
||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||
if c.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
c.l.Debug("resetting conntrack cache", "len", ll)
|
||||
}
|
||||
c.cache = make(ConntrackCache, ll)
|
||||
}
|
||||
|
||||
69
firewall/cache_test.go
Normal file
69
firewall/cache_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// The tests below pin the log format produced by ConntrackCacheTicker.Get
|
||||
// so changes cannot silently break what operators are grepping for. The
|
||||
// ticker's internal state (cache + cacheTick) is poked directly to avoid
|
||||
// racing a goroutine-driven tick in tests.
|
||||
|
||||
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
|
||||
t.Helper()
|
||||
c := &ConntrackCacheTicker{
|
||||
l: l,
|
||||
cache: make(ConntrackCache, cacheLen),
|
||||
}
|
||||
for i := 0; i < cacheLen; i++ {
|
||||
c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{}
|
||||
}
|
||||
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
|
||||
return c
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||
|
||||
c := newFixedTicker(t, l, 3)
|
||||
c.Get()
|
||||
|
||||
assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String())
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug)
|
||||
|
||||
c := newFixedTicker(t, l, 2)
|
||||
c.Get()
|
||||
|
||||
assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
|
||||
|
||||
c := newFixedTicker(t, l, 5)
|
||||
c.Get()
|
||||
|
||||
assert.Empty(t, buf.String())
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||
|
||||
c := newFixedTicker(t, l, 0)
|
||||
c.Get()
|
||||
|
||||
assert.Empty(t, buf.String())
|
||||
}
|
||||
100
firewall_test.go
100
firewall_test.go
@@ -3,13 +3,13 @@ package nebula
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_AddRule(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
|
||||
c := &dummyCert{}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
@@ -177,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_Drop(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
p := firewall.Packet{
|
||||
@@ -254,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_DropV6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
@@ -485,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
}
|
||||
|
||||
func TestFirewall_Drop2(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
@@ -544,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
@@ -633,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3V6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
|
||||
@@ -671,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
@@ -736,9 +729,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
@@ -880,9 +872,8 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||
|
||||
@@ -1045,25 +1036,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conf := config.NewC(l)
|
||||
conf := config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||
|
||||
// Test both port and code
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||
|
||||
// Test missing host, group, cidr, ca_name and ca_sha
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||
|
||||
// Test code/port error
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||
@@ -1073,25 +1064,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||
|
||||
// Test proto error
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||
|
||||
// Test cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test local_cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test both group and groups
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||
@@ -1100,35 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
// Test adding tcp rule
|
||||
conf := config.NewC(l)
|
||||
conf := config.NewC(test.NewLogger())
|
||||
mf := &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test adding udp rule
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test adding icmp rule
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test adding icmp rule no port
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test adding any rule
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
@@ -1136,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
|
||||
// Test adding rule with cidr
|
||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
@@ -1151,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
|
||||
// Test adding rule with cidr ipv6
|
||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test adding rule with any cidr
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test adding rule with junk cidr
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||
|
||||
// Test adding rule with local_cidr ipv6
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
||||
|
||||
// Test adding rule with any local_cidr
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
||||
|
||||
// Test adding rule with junk local_cidr
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_name
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
||||
|
||||
// Test single group
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test single groups
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test multiple AND groups
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
||||
|
||||
// Test Add error
|
||||
conf = config.NewC(l)
|
||||
conf = config.NewC(test.NewLogger())
|
||||
mf = &mockFirewall{}
|
||||
mf.nextCallReturn = errors.New("test error")
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||
@@ -1234,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_convertRule(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
|
||||
// Ensure group array of 1 is converted and a warning is printed
|
||||
c := map[string]any{
|
||||
@@ -1244,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
}
|
||||
|
||||
r, err := convertRule(l, c, "test", 1)
|
||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
||||
assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
|
||||
assert.Contains(t, ob.String(), "table=test")
|
||||
assert.Contains(t, ob.String(), "rule=1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
||||
|
||||
@@ -1270,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFirewall_convertRuleSanity(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
|
||||
noWarningPlease := []map[string]any{
|
||||
{"group": "group1"},
|
||||
@@ -1386,7 +1377,7 @@ type testsetup struct {
|
||||
fw *Firewall
|
||||
}
|
||||
|
||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
||||
func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
|
||||
c := dummyCert{
|
||||
name: "me",
|
||||
networks: myPrefixes,
|
||||
@@ -1397,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
|
||||
return newSetupFromCert(t, l, c)
|
||||
}
|
||||
|
||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
||||
func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
for _, prefix := range c.Networks() {
|
||||
myVpnNetworksTable.Insert(prefix)
|
||||
@@ -1414,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
||||
|
||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
l := test.NewLoggerWithOutput(ob)
|
||||
|
||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
||||
|
||||
1
go.mod
1
go.mod
@@ -18,7 +18,6 @@ require (
|
||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||
github.com/sirupsen/logrus v1.9.4
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||
github.com/stretchr/testify v1.11.1
|
||||
|
||||
2
go.sum
2
go.sum
@@ -133,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
||||
|
||||
541
handshake_ix.go
541
handshake_ix.go
@@ -2,11 +2,12 @@ package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
@@ -18,8 +19,11 @@ import (
|
||||
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
err := f.handshakeManager.allocateIndex(hh)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
f.l.Error("Failed to generate index",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
|
||||
crt := cs.getCertificate(v)
|
||||
if crt == nil {
|
||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", v,
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
crtHs := cs.getHandshakeBytes(v)
|
||||
if crtHs == nil {
|
||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", v,
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Failed to create connection state")
|
||||
f.l.Error("Failed to create connection state",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", v,
|
||||
)
|
||||
return false
|
||||
}
|
||||
hh.hostinfo.ConnectionState = ci
|
||||
@@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("certVersion", v).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
f.l.Error("Failed to marshal handshake message",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"certVersion", v,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
|
||||
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
f.l.Error("Failed to call noise.WriteMessage",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
cs := f.pki.getCertState()
|
||||
crt := cs.GetDefaultCertificate()
|
||||
if crt == nil {
|
||||
f.l.WithField("from", via).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.initiatingVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||
"from", via,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", cs.initiatingVersion,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to create connection state")
|
||||
f.l.Error("Failed to create connection state",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to call noise.ReadMessage")
|
||||
f.l.Error("Failed to call noise.ReadMessage",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = hs.Unmarshal(msg)
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.WithError(err).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed unmarshal handshake message")
|
||||
f.l.Error("Failed unmarshal handshake message",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake did not contain a certificate")
|
||||
f.l.Info("Handshake did not contain a certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
e := f.l.WithError(err).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("certVpnNetworks", rc.Networks()).
|
||||
WithField("certFingerprint", fp)
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
e = e.WithField("cert", rc)
|
||||
attrs := []slog.Attr{
|
||||
slog.Any("error", err),
|
||||
slog.Any("from", via),
|
||||
slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
|
||||
slog.Any("certVpnNetworks", rc.Networks()),
|
||||
slog.String("certFingerprint", fp),
|
||||
}
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
attrs = append(attrs, slog.Any("cert", rc))
|
||||
}
|
||||
|
||||
e.Info("Invalid certificate from host")
|
||||
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||
// callers grow conditionally, which has no pair-form equivalent.
|
||||
//nolint:sloglint
|
||||
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||
f.l.WithField("from", via).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
||||
f.l.Info("public key mismatch between certificate and handshake",
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert", remoteCert,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if myCertOtherVersion == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithError(err).WithFields(m{
|
||||
"from": via,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert": remoteCert,
|
||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Might be unable to handshake with host due to missing certificate version",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert", remoteCert,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Record the certificate we are actually using
|
||||
@@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
f.l.WithError(err).WithField("from", via).
|
||||
WithField("cert", remoteCert).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("No networks in certificate")
|
||||
f.l.Info("No networks in certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"cert", remoteCert,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
f.l.Error("Refusing to handshake with myself",
|
||||
"vpnNetworks", vpnNetworks,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
vpnAddrs[i] = network.Addr()
|
||||
@@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
if !via.IsRelayed {
|
||||
// We only want to apply the remote allow list for direct tunnels here
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
||||
Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
myIndex, err := generateIndex(f.l)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
f.l.Error("Failed to generate index",
|
||||
"error", err,
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
},
|
||||
}
|
||||
|
||||
msgRxL := f.l.WithFields(m{
|
||||
"vpnAddrs": vpnAddrs,
|
||||
"from": via,
|
||||
"certName": certName,
|
||||
"certVersion": certVersion,
|
||||
"fingerprint": fingerprint,
|
||||
"issuer": issuer,
|
||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||
"responderIndex": hs.Details.ResponderIndex,
|
||||
"remoteIndex": h.RemoteIndex,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
})
|
||||
msgRxL := f.l.With(
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
@@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||
if hs.Details.Cert == nil {
|
||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||
"myCertVersion", ci.myCert.Version(),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
f.l.Error("Failed to marshal handshake message",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
f.l.Error("Failed to call noise.WriteMessage",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
} else if dKey == nil || eKey == nil {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
||||
f.l.Error("Noise did not arrive at a key",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
if !via.IsRelayed {
|
||||
err := f.outside.WriteTo(msg, via.UdpAddr)
|
||||
if err != nil {
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
f.l.Error("Failed to send handshake message",
|
||||
"vpnAddrs", existing.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cached", true,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
f.l.Info("Handshake message sent",
|
||||
"vpnAddrs", existing.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cached", true,
|
||||
)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
@@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
f.l.Info("Handshake message sent",
|
||||
"vpnAddrs", existing.vpnAddrs,
|
||||
"relay", via.relayHI.vpnAddrs[0],
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cached", true,
|
||||
)
|
||||
return
|
||||
}
|
||||
case ErrExistingHostInfo:
|
||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake too old")
|
||||
f.l.Info("Handshake too old",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"oldHandshakeTime", existing.lastHandshakeTime,
|
||||
"newHandshakeTime", hostinfo.lastHandshakeTime,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
return
|
||||
case ErrLocalIndexCollision:
|
||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
|
||||
Error("Failed to add HostInfo due to localIndex collision")
|
||||
f.l.Error("Failed to add HostInfo due to localIndex collision",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"localIndex", hostinfo.localIndexId,
|
||||
"collision", existing.vpnAddrs,
|
||||
)
|
||||
return
|
||||
default:
|
||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||
// And we forget to update it here
|
||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to add HostInfo to HostMap")
|
||||
f.l.Error("Failed to add HostInfo to HostMap",
|
||||
"error", err,
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if !via.IsRelayed {
|
||||
err = f.outside.WriteTo(msg, via.UdpAddr)
|
||||
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
||||
log := f.l.With(
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to send handshake")
|
||||
log.Error("Failed to send handshake", "error", err)
|
||||
} else {
|
||||
log.Info("Handshake message sent")
|
||||
}
|
||||
@@ -448,14 +533,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
||||
// it's correctly marked as working.
|
||||
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Handshake message sent")
|
||||
f.l.Info("Handshake message sent",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"relay", via.relayHI.vpnAddrs[0],
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
}
|
||||
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
@@ -483,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
if !via.IsRelayed {
|
||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -491,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
ci := hostinfo.ConnectionState
|
||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||
Error("Failed to call noise.ReadMessage")
|
||||
f.l.Error("Failed to call noise.ReadMessage",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"header", h,
|
||||
)
|
||||
|
||||
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
|
||||
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
||||
// near future
|
||||
return false
|
||||
} else if dKey == nil || eKey == nil {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Error("Noise did not arrive at a key")
|
||||
f.l.Error("Noise did not arrive at a key",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
|
||||
// the handshake state machine. Tear it down
|
||||
@@ -512,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
hs := &NebulaHandshake{}
|
||||
err = hs.Unmarshal(msg)
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
f.l.Error("Failed unmarshal handshake message",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||
return true
|
||||
@@ -521,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("from", via).
|
||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Handshake did not contain a certificate")
|
||||
f.l.Info("Handshake did not contain a certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -535,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
e := f.l.WithError(err).WithField("from", via).
|
||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("certFingerprint", fp).
|
||||
WithField("certVpnNetworks", rc.Networks())
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
e = e.WithField("cert", rc)
|
||||
attrs := []slog.Attr{
|
||||
slog.Any("error", err),
|
||||
slog.Any("from", via),
|
||||
slog.Any("vpnAddrs", hostinfo.vpnAddrs),
|
||||
slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
|
||||
slog.String("certFingerprint", fp),
|
||||
slog.Any("certVpnNetworks", rc.Networks()),
|
||||
}
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
attrs = append(attrs, slog.Any("cert", rc))
|
||||
}
|
||||
|
||||
e.Info("Invalid certificate from host")
|
||||
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||
// callers grow conditionally, which has no pair-form equivalent.
|
||||
//nolint:sloglint
|
||||
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||
return true
|
||||
}
|
||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||
f.l.WithField("from", via).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
||||
f.l.Info("public key mismatch between certificate and handshake",
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cert", remoteCert,
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
f.l.WithError(err).WithField("from", via).
|
||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("cert", remoteCert).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("No networks in certificate")
|
||||
f.l.Info("No networks in certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"cert", remoteCert,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -601,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
|
||||
// Ensure the right host responded
|
||||
if !correctHostResponded {
|
||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||
WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Incorrect host responded to handshake")
|
||||
f.l.Info("Incorrect host responded to handshake",
|
||||
"intendedVpnAddrs", hostinfo.vpnAddrs,
|
||||
"haveVpnNetworks", vpnNetworks,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// Release our old handshake from pending, it should not continue
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
@@ -618,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
newHH.hostinfo.remotes.BlockRemote(via)
|
||||
|
||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
||||
WithField("vpnNetworks", vpnNetworks).
|
||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
||||
Info("Blocked addresses for handshakes")
|
||||
f.l.Info("Blocked addresses for handshakes",
|
||||
"blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
|
||||
"vpnNetworks", vpnNetworks,
|
||||
"remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
|
||||
)
|
||||
|
||||
// Swap the packet store to benefit the original intended recipient
|
||||
newHH.packetStore = hh.packetStore
|
||||
@@ -639,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
WithField("sentCachedPackets", len(hh.packetStore))
|
||||
msgRxL := f.l.With(
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"durationNs", duration,
|
||||
"sentCachedPackets", len(hh.packetStore),
|
||||
)
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
@@ -663,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("Sending stored packets",
|
||||
"count", len(hh.packetStore),
|
||||
)
|
||||
}
|
||||
|
||||
if len(hh.packetStore) > 0 {
|
||||
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
@@ -59,7 +59,7 @@ type HandshakeManager struct {
|
||||
metricInitiated metrics.Counter
|
||||
metricTimedOut metrics.Counter
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
|
||||
// can be used to trigger outbound handshake for the given vpnIp
|
||||
trigger chan netip.Addr
|
||||
@@ -78,32 +78,32 @@ type HandshakeHostInfo struct {
|
||||
hostinfo *HostInfo
|
||||
}
|
||||
|
||||
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||
func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||
if len(hh.packetStore) < 100 {
|
||||
tempPacket := make([]byte, len(packet))
|
||||
copy(tempPacket, packet)
|
||||
|
||||
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
hh.hostinfo.logger(l).
|
||||
WithField("length", len(hh.packetStore)).
|
||||
WithField("stored", true).
|
||||
Debugf("Packet store")
|
||||
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hh.hostinfo.logger(l).Debug("Packet store",
|
||||
"length", len(hh.packetStore),
|
||||
"stored", true,
|
||||
)
|
||||
}
|
||||
|
||||
} else {
|
||||
m.dropped.Inc(1)
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
hh.hostinfo.logger(l).
|
||||
WithField("length", len(hh.packetStore)).
|
||||
WithField("stored", false).
|
||||
Debugf("Packet store")
|
||||
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hh.hostinfo.logger(l).Debug("Packet store",
|
||||
"length", len(hh.packetStore),
|
||||
"stored", false,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
|
||||
func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
|
||||
return &HandshakeManager{
|
||||
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
|
||||
indexes: map[uint32]*HandshakeHostInfo{},
|
||||
@@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
|
||||
// First remote allow list check before we know the vpnIp
|
||||
if !via.IsRelayed {
|
||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
|
||||
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo := hh.hostinfo
|
||||
// If we are out of time, clean up
|
||||
if hh.counter >= hm.config.retries {
|
||||
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
|
||||
WithField("initiatorIndex", hh.hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
|
||||
Info("Handshake timed out")
|
||||
hh.hostinfo.logger(hm.l).Info("Handshake timed out",
|
||||
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
|
||||
"initiatorIndex", hh.hostinfo.localIndexId,
|
||||
"remoteIndex", hh.hostinfo.remoteIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"durationNs", time.Since(hh.startTime).Nanoseconds(),
|
||||
)
|
||||
hm.metricTimedOut.Inc(1)
|
||||
hm.DeleteHostInfo(hostinfo)
|
||||
return
|
||||
@@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithField("udpAddr", addr).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
hostinfo.logger(hm.l).Error("Failed to send handshake message",
|
||||
"udpAddr", addr,
|
||||
"initiatorIndex", hostinfo.localIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"error", err,
|
||||
)
|
||||
|
||||
} else {
|
||||
sentTo = append(sentTo, addr)
|
||||
@@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
|
||||
// so only log when the list of remotes has changed
|
||||
if remotesHaveChanged {
|
||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message sent")
|
||||
} else if hm.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Debug("Handshake message sent")
|
||||
hostinfo.logger(hm.l).Info("Handshake message sent",
|
||||
"udpAddrs", sentTo,
|
||||
"initiatorIndex", hostinfo.localIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
} else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(hm.l).Debug("Handshake message sent",
|
||||
"udpAddrs", sentTo,
|
||||
"initiatorIndex", hostinfo.localIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
}
|
||||
|
||||
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
|
||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||
hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
|
||||
// Send a RelayRequest to all known Relay IP's
|
||||
for _, relay := range hostinfo.remotes.relays {
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
@@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
|
||||
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
|
||||
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
|
||||
hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
|
||||
hm.f.Handshake(relay)
|
||||
continue
|
||||
}
|
||||
@@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
if relayHostInfo.remote.IsValid() {
|
||||
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
||||
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
|
||||
}
|
||||
|
||||
m := NebulaControl{
|
||||
@@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).
|
||||
WithError(err).
|
||||
Error("Failed to marshal Control message to create relay")
|
||||
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||
} else {
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.f.myVpnAddrs[0],
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": idx,
|
||||
"relay": relay}).
|
||||
Info("send CreateRelayRequest")
|
||||
hm.l.Info("send CreateRelayRequest",
|
||||
"relayFrom", hm.f.myVpnAddrs[0],
|
||||
"relayTo", vpnIp,
|
||||
"initiatorRelayIndex", idx,
|
||||
"relay", relay,
|
||||
)
|
||||
}
|
||||
}
|
||||
continue
|
||||
@@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
|
||||
switch existingRelay.State {
|
||||
case Established:
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
|
||||
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
|
||||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
||||
case Disestablished:
|
||||
// Mark this relay as 'requested'
|
||||
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
||||
fallthrough
|
||||
case Requested:
|
||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
||||
hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
|
||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
||||
m := NebulaControl{
|
||||
Type: NebulaControl_CreateRelayRequest,
|
||||
@@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
}
|
||||
msg, err := m.Marshal()
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).
|
||||
WithError(err).
|
||||
Error("Failed to marshal Control message to create relay")
|
||||
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||
} else {
|
||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
hm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": hm.f.myVpnAddrs[0],
|
||||
"relayTo": vpnIp,
|
||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||
"relay": relay}).
|
||||
Info("send CreateRelayRequest")
|
||||
hm.l.Info("send CreateRelayRequest",
|
||||
"relayFrom", hm.f.myVpnAddrs[0],
|
||||
"relayTo", vpnIp,
|
||||
"initiatorRelayIndex", existingRelay.LocalIndex,
|
||||
"relay", relay,
|
||||
)
|
||||
}
|
||||
case PeerRequested:
|
||||
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
|
||||
fallthrough
|
||||
default:
|
||||
hostinfo.logger(hm.l).
|
||||
WithField("vpnIp", vpnIp).
|
||||
WithField("state", existingRelay.State).
|
||||
WithField("relay", relay).
|
||||
Errorf("Relay unexpected state")
|
||||
hostinfo.logger(hm.l).Error("Relay unexpected state",
|
||||
"vpnIp", vpnIp,
|
||||
"state", existingRelay.State,
|
||||
"relay", relay,
|
||||
)
|
||||
|
||||
}
|
||||
}
|
||||
@@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
||||
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger(hm.l).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||
"remoteIndex", hostinfo.remoteIndexId,
|
||||
"collision", existingRemoteIndex.vpnAddrs,
|
||||
)
|
||||
}
|
||||
|
||||
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
||||
@@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||
if found && existingRemoteIndex != nil {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger(hm.l).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||
"remoteIndex", hostinfo.remoteIndexId,
|
||||
"collision", existingRemoteIndex.vpnAddrs,
|
||||
)
|
||||
}
|
||||
|
||||
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
|
||||
@@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||
hm.indexes = map[uint32]*HandshakeHostInfo{}
|
||||
}
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
Debug("Pending hostmap hostInfo deleted")
|
||||
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hm.l.Debug("Pending hostmap hostInfo deleted",
|
||||
"hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() {
|
||||
|
||||
// Utility functions below
|
||||
|
||||
func generateIndex(l *logrus.Logger) (uint32, error) {
|
||||
func generateIndex(l *slog.Logger) (uint32, error) {
|
||||
b := make([]byte, 4)
|
||||
|
||||
// Let zero mean we don't know the ID, so don't generate zero
|
||||
@@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
|
||||
for index == 0 {
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
l.Errorln(err)
|
||||
l.Error("Failed to generate index", "error", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
index = binary.BigEndian.Uint32(b)
|
||||
}
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("index", index).
|
||||
Debug("Generated index")
|
||||
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
l.Debug("Generated index", "index", index)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
76
hostmap.go
76
hostmap.go
@@ -1,9 +1,11 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -13,10 +15,10 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
)
|
||||
|
||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||
@@ -60,7 +62,7 @@ type HostMap struct {
|
||||
RemoteIndexes map[uint32]*HostInfo
|
||||
Hosts map[netip.Addr]*HostInfo
|
||||
preferredRanges atomic.Pointer[[]netip.Prefix]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
|
||||
@@ -313,7 +315,7 @@ type cachedPacketMetrics struct {
|
||||
dropped metrics.Counter
|
||||
}
|
||||
|
||||
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
||||
func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
|
||||
hm := newHostMap(l)
|
||||
|
||||
hm.reload(c, true)
|
||||
@@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
||||
hm.reload(c, false)
|
||||
})
|
||||
|
||||
l.WithField("preferredRanges", hm.GetPreferredRanges()).
|
||||
Info("Main HostMap created")
|
||||
l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
|
||||
|
||||
return hm
|
||||
}
|
||||
|
||||
func newHostMap(l *logrus.Logger) *HostMap {
|
||||
func newHostMap(l *slog.Logger) *HostMap {
|
||||
return &HostMap{
|
||||
Indexes: map[uint32]*HostInfo{},
|
||||
Relays: map[uint32]*HostInfo{},
|
||||
@@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
||||
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
||||
|
||||
if err != nil {
|
||||
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
|
||||
hm.l.Warn("Failed to parse preferred ranges, ignoring",
|
||||
"error", err,
|
||||
"range", rawPreferredRanges,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
||||
|
||||
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
||||
if !initial {
|
||||
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
|
||||
hm.l.Info("preferred_ranges changed",
|
||||
"oldPreferredRanges", *oldRanges,
|
||||
"newPreferredRanges", preferredRanges,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
|
||||
hm.Indexes = map[uint32]*HostInfo{}
|
||||
}
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
|
||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
Debug("Hostmap hostInfo deleted")
|
||||
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hm.l.Debug("Hostmap hostInfo deleted",
|
||||
"hostMap", m{"mapTotalSize": len(hm.Hosts),
|
||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
|
||||
)
|
||||
}
|
||||
|
||||
if isLastHostinfo {
|
||||
@@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
|
||||
Debug("Hostmap vpnIp added")
|
||||
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hm.l.Debug("Hostmap vpnIp added",
|
||||
"hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
|
||||
}
|
||||
}
|
||||
|
||||
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
// logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
|
||||
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
|
||||
if i == nil {
|
||||
return logrus.NewEntry(l)
|
||||
return l
|
||||
}
|
||||
|
||||
li := l.WithField("vpnAddrs", i.vpnAddrs).
|
||||
WithField("localIndex", i.localIndexId).
|
||||
WithField("remoteIndex", i.remoteIndexId)
|
||||
li := l.With(
|
||||
"vpnAddrs", i.vpnAddrs,
|
||||
"localIndex", i.localIndexId,
|
||||
"remoteIndex", i.remoteIndexId,
|
||||
)
|
||||
|
||||
if connState := i.ConnectionState; connState != nil {
|
||||
if peerCert := connState.peerCert; peerCert != nil {
|
||||
li = li.WithField("certName", peerCert.Certificate.Name())
|
||||
li = li.With("certName", peerCert.Certificate.Name())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
|
||||
// Utility functions
|
||||
|
||||
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
//FIXME: This function is pretty garbage
|
||||
var finalAddrs []netip.Addr
|
||||
ifaces, _ := net.Interfaces()
|
||||
for _, i := range ifaces {
|
||||
allow := allowList.AllowName(i.Name)
|
||||
if l.Level >= logrus.TraceLevel {
|
||||
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
|
||||
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||
l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
|
||||
"interfaceName", i.Name,
|
||||
"allow", allow,
|
||||
)
|
||||
}
|
||||
|
||||
if !allow {
|
||||
@@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
}
|
||||
|
||||
if !addr.IsValid() {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
|
||||
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
l.Debug("addr was invalid", "localAddr", rawAddr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||
|
||||
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
||||
isAllowed := allowList.Allow(addr)
|
||||
if l.Level >= logrus.TraceLevel {
|
||||
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
|
||||
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||
l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
|
||||
"localAddr", addr,
|
||||
"allowed", isAllowed,
|
||||
)
|
||||
}
|
||||
if !isAllowed {
|
||||
continue
|
||||
|
||||
@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
|
||||
func TestHostMap_reload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
c := config.NewC(test.NewLogger())
|
||||
|
||||
hm := NewHostMapFromConfig(l, c)
|
||||
|
||||
|
||||
119
inside.go
119
inside.go
@@ -1,9 +1,10 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
@@ -14,8 +15,11 @@ import (
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Error while validating outbound packet",
|
||||
"packet", packet,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
if immediatelyForwardToSelf {
|
||||
_, err := f.readers[q].Write(packet)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to forward to tun")
|
||||
f.l.Error("Failed to forward to tun", "error", err)
|
||||
}
|
||||
}
|
||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||
@@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||
WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
|
||||
"vpnAddr", fwPacket.RemoteAddr,
|
||||
"fwPacket", fwPacket,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).
|
||||
WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping outbound packet")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("dropping outbound packet",
|
||||
"fwPacket", fwPacket,
|
||||
"reason", dropReason,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||
|
||||
_, err := f.readers[q].Write(out)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to write to tun")
|
||||
f.l.Error("Failed to write to tun", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
||||
}
|
||||
|
||||
if len(out) > iputil.MaxRejectPacketSize {
|
||||
if f.l.GetLevel() >= logrus.InfoLevel {
|
||||
f.l.
|
||||
WithField("packet", packet).
|
||||
WithField("outPacket", out).
|
||||
Info("rejectOutside: packet too big, not sending")
|
||||
if f.l.Enabled(context.Background(), slog.LevelInfo) {
|
||||
f.l.Info("rejectOutside: packet too big, not sending",
|
||||
"packet", packet,
|
||||
"outPacket", out,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
|
||||
// This would also need to interact with unsafe_route updates through reloading the config or
|
||||
// use of the use_system_route_table option
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("destination", destinationAddr).
|
||||
WithField("originalGateway", gatewayAddr).
|
||||
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
|
||||
"destination", destinationAddr,
|
||||
"originalGateway", gatewayAddr,
|
||||
)
|
||||
}
|
||||
|
||||
for i := range gateways {
|
||||
@@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
fp := &firewall.Packet{}
|
||||
err := newPacket(p, false, fp)
|
||||
if err != nil {
|
||||
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
||||
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping cached packet")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("dropping cached packet",
|
||||
"fwPacket", fp,
|
||||
"reason", dropReason,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
|
||||
})
|
||||
|
||||
if hostInfo == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddr", vpnAddr).
|
||||
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
|
||||
"vpnAddr", vpnAddr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
via.ConnectionState.writeLock.Unlock()
|
||||
}
|
||||
via.logger(f.l).
|
||||
WithField("outCap", cap(out)).
|
||||
WithField("payloadLen", len(ad)).
|
||||
WithField("headerLen", len(out)).
|
||||
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()).
|
||||
Error("SendVia out buffer not large enough for relay")
|
||||
via.logger(f.l).Error("SendVia out buffer not large enough for relay",
|
||||
"outCap", cap(out),
|
||||
"payloadLen", len(ad),
|
||||
"headerLen", len(out),
|
||||
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
||||
via.ConnectionState.writeLock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
|
||||
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
|
||||
return
|
||||
}
|
||||
err = f.writers[0].WriteTo(out, via.remote)
|
||||
if err != nil {
|
||||
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
|
||||
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
|
||||
}
|
||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||
}
|
||||
@@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
ci.writeLock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
|
||||
"error", err,
|
||||
"udpAddr", remote,
|
||||
"counter", c,
|
||||
"attemptedCounter", c,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||
"error", err,
|
||||
"udpAddr", remote,
|
||||
)
|
||||
}
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||
"error", err,
|
||||
"udpAddr", remote,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Try to send via a relay
|
||||
@@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
|
||||
"relay", relayIP,
|
||||
"error", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||
|
||||
66
interface.go
66
interface.go
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@@ -46,7 +47,7 @@ type InterfaceConfig struct {
|
||||
reQueryWait time.Duration
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
@@ -100,7 +101,7 @@ type Interface struct {
|
||||
messageMetrics *MessageMetrics
|
||||
cachedPacketMetrics *cachedPacketMetrics
|
||||
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
type EncWriter interface {
|
||||
@@ -223,13 +224,16 @@ func (f *Interface) activate() error {
|
||||
|
||||
addr, err := f.outside.LocalAddr()
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to get udp listen address")
|
||||
f.l.Error("Failed to get udp listen address", "error", err)
|
||||
}
|
||||
|
||||
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
|
||||
WithField("build", f.version).WithField("udpAddr", addr).
|
||||
WithField("boringcrypto", boringEnabled()).
|
||||
Info("Nebula interface is active")
|
||||
f.l.Info("Nebula interface is active",
|
||||
"interface", f.inside.Name(),
|
||||
"networks", f.myVpnNetworks,
|
||||
"build", f.version,
|
||||
"udpAddr", addr,
|
||||
"boringcrypto", boringEnabled(),
|
||||
)
|
||||
|
||||
if f.routines > 1 {
|
||||
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
||||
@@ -305,7 +309,7 @@ func (f *Interface) listenOut(i int) {
|
||||
li = f.outside
|
||||
}
|
||||
|
||||
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
|
||||
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
plaintext := make([]byte, udp.MTU)
|
||||
h := &header.H{}
|
||||
@@ -313,15 +317,15 @@ func (f *Interface) listenOut(i int) {
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
|
||||
})
|
||||
|
||||
if err != nil && !f.closed.Load() {
|
||||
f.l.WithError(err).Error("Error while reading inbound packet, closing")
|
||||
f.l.Error("Error while reading inbound packet, closing", "error", err)
|
||||
f.onFatal(err)
|
||||
}
|
||||
|
||||
f.l.Debugf("underlay reader %v is done", i)
|
||||
f.l.Debug("underlay reader is done", "reader", i)
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
@@ -330,22 +334,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := reader.Read(packet)
|
||||
if err != nil {
|
||||
if !f.closed.Load() {
|
||||
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
|
||||
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
|
||||
f.onFatal(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
||||
}
|
||||
|
||||
f.l.Debugf("overlay reader %v is done", i)
|
||||
f.l.Debug("overlay reader is done", "reader", i)
|
||||
}
|
||||
|
||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
@@ -365,7 +369,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
|
||||
if initial || c.HasChanged("pki.disconnect_invalid") {
|
||||
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
|
||||
if !initial {
|
||||
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
|
||||
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -379,7 +383,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||
f.l.Error("Error while creating firewall during reload", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -392,10 +396,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
||||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||
// safe and just reset conntrack in this case.
|
||||
if fw.rulesVersion == 0 {
|
||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
||||
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
|
||||
"firewallHashes", fw.GetRuleHashes(),
|
||||
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||
"rulesVersion", fw.rulesVersion,
|
||||
)
|
||||
} else {
|
||||
fw.Conntrack = conntrack
|
||||
}
|
||||
@@ -403,10 +408,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
||||
f.firewall = fw
|
||||
|
||||
oldFw.Destroy()
|
||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Info("New firewall has been installed")
|
||||
f.l.Info("New firewall has been installed",
|
||||
"firewallHashes", fw.GetRuleHashes(),
|
||||
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||
"rulesVersion", fw.rulesVersion,
|
||||
)
|
||||
}
|
||||
|
||||
func (f *Interface) reloadSendRecvError(c *config.C) {
|
||||
@@ -428,8 +434,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
|
||||
}
|
||||
}
|
||||
|
||||
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
|
||||
Info("Loaded send_recv_error config")
|
||||
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,8 +457,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
|
||||
}
|
||||
}
|
||||
|
||||
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
|
||||
Info("Loaded accept_recv_error config")
|
||||
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -527,7 +531,7 @@ func (f *Interface) Close() error {
|
||||
for i, u := range f.writers {
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket")
|
||||
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
228
lighthouse.go
228
lighthouse.go
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -15,10 +16,10 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
@@ -76,12 +77,12 @@ type LightHouse struct {
|
||||
|
||||
metrics *MessageMetrics
|
||||
metricHolepunchTx metrics.Counter
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
||||
// addrMap should be nil unless this is during a config reload
|
||||
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||
func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
||||
if amLighthouse && nebulaPort == 0 {
|
||||
@@ -133,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
||||
case *util.ContextualError:
|
||||
v.Log(l)
|
||||
case error:
|
||||
l.WithError(err).Error("failed to reload lighthouse")
|
||||
l.Error("failed to reload lighthouse", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -205,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
||||
addr := addrs[0].Unmap()
|
||||
if lh.myVpnNetworksTable.Contains(addr) {
|
||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
||||
lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
|
||||
"addr", rawAddr,
|
||||
"entry", i+1,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -224,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
|
||||
|
||||
if !initial {
|
||||
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
|
||||
lh.l.Info("lighthouse.interval changed",
|
||||
"interval", lh.interval.Load(),
|
||||
)
|
||||
|
||||
if lh.updateCancel != nil {
|
||||
// May not always have a running routine
|
||||
@@ -336,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
||||
configRIP, err := netip.ParseAddr(v)
|
||||
if err != nil {
|
||||
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
|
||||
lh.l.Warn("Parse relay from config failed",
|
||||
"relay", v,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
lh.l.WithField("relay", v).Info("Read relay from config")
|
||||
lh.l.Info("Read relay from config", "relay", v)
|
||||
relaysForMe = append(relaysForMe, configRIP)
|
||||
}
|
||||
}
|
||||
@@ -363,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
||||
lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
|
||||
"vpnAddr", addr,
|
||||
"networks", lh.myVpnNetworks,
|
||||
)
|
||||
}
|
||||
out[i] = addr
|
||||
}
|
||||
@@ -435,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
||||
}
|
||||
|
||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
||||
lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
|
||||
"vpnAddr", vpnAddr,
|
||||
"networks", lh.myVpnNetworks,
|
||||
"entry", i+1,
|
||||
)
|
||||
}
|
||||
|
||||
vals, ok := v.([]any)
|
||||
@@ -537,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
||||
lh.Lock()
|
||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||
if ok {
|
||||
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
|
||||
for _, addr := range allVpnAddrs {
|
||||
srm := lh.addrMap[addr]
|
||||
if srm == rm {
|
||||
delete(lh.addrMap, addr)
|
||||
if lh.l.Level >= logrus.DebugLevel {
|
||||
lh.l.Debugf("deleting %s from lighthouse.", addr)
|
||||
if debugEnabled {
|
||||
lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -659,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
||||
|
||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
||||
Trace("remoteAllowList.Allow")
|
||||
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"udpAddr", to,
|
||||
"allow", allow,
|
||||
)
|
||||
}
|
||||
if !allow {
|
||||
return false
|
||||
@@ -678,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
|
||||
udpAddr := protoV4AddrPortToNetAddrPort(to)
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
||||
Trace("remoteAllowList.Allow")
|
||||
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||
"vpnAddr", vpnAddr,
|
||||
"udpAddr", udpAddr,
|
||||
"allow", allow,
|
||||
)
|
||||
}
|
||||
|
||||
if !allow {
|
||||
@@ -698,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
|
||||
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
|
||||
udpAddr := protoV6AddrPortToNetAddrPort(to)
|
||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||
if lh.l.Level >= logrus.TraceLevel {
|
||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
||||
Trace("remoteAllowList.Allow")
|
||||
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||
"vpnAddr", vpnAddr,
|
||||
"udpAddr", udpAddr,
|
||||
"allow", allow,
|
||||
)
|
||||
}
|
||||
|
||||
if !allow {
|
||||
@@ -775,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
|
||||
if v == cert.Version1 {
|
||||
if !addr.Is4() {
|
||||
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
|
||||
Error("Can't query lighthouse for v6 address using a v1 protocol")
|
||||
lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
|
||||
"queryVpnAddr", addr,
|
||||
"lighthouseAddr", lhVpnAddr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -787,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
|
||||
v1Query, err = msg.Marshal()
|
||||
if err != nil {
|
||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
||||
WithField("lighthouseAddr", lhVpnAddr).
|
||||
Error("Failed to marshal lighthouse v1 query payload")
|
||||
lh.l.Error("Failed to marshal lighthouse v1 query payload",
|
||||
"error", err,
|
||||
"queryVpnAddr", addr,
|
||||
"lighthouseAddr", lhVpnAddr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -804,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
|
||||
v2Query, err = msg.Marshal()
|
||||
if err != nil {
|
||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
||||
WithField("lighthouseAddr", lhVpnAddr).
|
||||
Error("Failed to marshal lighthouse v2 query payload")
|
||||
lh.l.Error("Failed to marshal lighthouse v2 query payload",
|
||||
"error", err,
|
||||
"queryVpnAddr", addr,
|
||||
"lighthouseAddr", lhVpnAddr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -815,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
queried++
|
||||
|
||||
} else {
|
||||
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
|
||||
lh.l.Debug("unsupported protocol version",
|
||||
"op", "query",
|
||||
"queryVpnAddr", addr,
|
||||
"version", v,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -907,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
|
||||
if v == cert.Version1 {
|
||||
if v1Update == nil {
|
||||
if !lh.myVpnNetworks[0].Addr().Is4() {
|
||||
lh.l.WithField("lighthouseAddr", lhVpnAddr).
|
||||
Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
|
||||
lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
|
||||
"lighthouseAddr", lhVpnAddr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
var relays []uint32
|
||||
@@ -932,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
|
||||
|
||||
v1Update, err = msg.Marshal()
|
||||
if err != nil {
|
||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
||||
Error("Error while marshaling for lighthouse v1 update")
|
||||
lh.l.Error("Error while marshaling for lighthouse v1 update",
|
||||
"error", err,
|
||||
"lighthouseAddr", lhVpnAddr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -959,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
|
||||
|
||||
v2Update, err = msg.Marshal()
|
||||
if err != nil {
|
||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
||||
Error("Error while marshaling for lighthouse v2 update")
|
||||
lh.l.Error("Error while marshaling for lighthouse v2 update",
|
||||
"error", err,
|
||||
"lighthouseAddr", lhVpnAddr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -969,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
|
||||
updated++
|
||||
|
||||
} else {
|
||||
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v)
|
||||
lh.l.Debug("unsupported protocol version",
|
||||
"op", "update",
|
||||
"version", v,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -983,7 +1024,7 @@ type LightHouseHandler struct {
|
||||
out []byte
|
||||
pb []byte
|
||||
meta *NebulaMeta
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
||||
@@ -1032,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
||||
n := lhh.resetMeta()
|
||||
err := n.Unmarshal(p)
|
||||
if err != nil {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
||||
Error("Failed to unmarshal lighthouse packet")
|
||||
lhh.l.Error("Failed to unmarshal lighthouse packet",
|
||||
"error", err,
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
"udpAddr", rAddr,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if n.Details == nil {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
||||
Error("Invalid lighthouse update")
|
||||
lhh.l.Error("Invalid lighthouse update",
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
"udpAddr", rAddr,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1067,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
||||
// Exit if we don't answer queries
|
||||
if !lhh.lh.amLighthouse {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("I don't answer queries, but received one", "from", addr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
||||
Debugln("Dropping malformed HostQuery")
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("Dropping malformed HostQuery",
|
||||
"from", fromVpnAddrs,
|
||||
"details", n.Details,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
"queryVpnAddr", queryVpnAddr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1110,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
||||
lhh.l.Error("Failed to marshal lighthouse host query reply",
|
||||
"error", err,
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1138,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
||||
if ok {
|
||||
whereToPunch = newDest
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("unable to punch to host, no addresses in common",
|
||||
"to", crt.Networks(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1165,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
|
||||
lhh.l.Error("Failed to marshal lighthouse host was queried for",
|
||||
"error", err,
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1207,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||
}
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("unsupported protocol version",
|
||||
"op", "coalesceAnswers",
|
||||
"version", v,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1221,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
||||
|
||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Error("dropping malformed HostQueryReply",
|
||||
"error", err,
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1247,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||
if !lhh.lh.amLighthouse {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1271,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
|
||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("Host sent invalid update",
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
"answer", detailsVpnAddr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1294,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
switch useVersion {
|
||||
case cert.Version1:
|
||||
if !fromVpnAddrs[0].Is4() {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
||||
lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
)
|
||||
return
|
||||
}
|
||||
vpnAddrB := fromVpnAddrs[0].As4()
|
||||
@@ -1302,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
case cert.Version2:
|
||||
// do nothing, we want to send a blank message
|
||||
default:
|
||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
||||
lhh.l.Error("invalid protocol version", "useVersion", useVersion)
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := n.MarshalTo(lhh.pb)
|
||||
if err != nil {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
|
||||
lhh.l.Error("Failed to marshal lighthouse host update ack",
|
||||
"error", err,
|
||||
"vpnAddrs", fromVpnAddrs,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1325,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
|
||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("dropping invalid HostPunchNotification",
|
||||
"details", n.Details,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1343,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
||||
}()
|
||||
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("Punching",
|
||||
"vpnPeer", vpnPeer,
|
||||
"logVpnAddr", logVpnAddr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1369,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
if lhh.lh.punchy.GetRespond() {
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
lhh.l.Debug("Sending a nebula test packet",
|
||||
"vpnAddr", detailsVpnAddr,
|
||||
)
|
||||
}
|
||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||
|
||||
45
logger.go
45
logger.go
@@ -1,45 +0,0 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
func configLogger(l *logrus.Logger, c *config.C) error {
|
||||
// set up our logging level
|
||||
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
||||
}
|
||||
l.SetLevel(logLevel)
|
||||
|
||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
||||
fullTimestamp := (timestampFormat != "")
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = time.RFC3339
|
||||
}
|
||||
|
||||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
||||
switch logFormat {
|
||||
case "text":
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
TimestampFormat: timestampFormat,
|
||||
FullTimestamp: fullTimestamp,
|
||||
DisableTimestamp: disableTimestamp,
|
||||
}
|
||||
case "json":
|
||||
l.Formatter = &logrus.JSONFormatter{
|
||||
TimestampFormat: timestampFormat,
|
||||
DisableTimestamp: disableTimestamp,
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
233
logging/logger.go
Normal file
233
logging/logger.go
Normal file
@@ -0,0 +1,233 @@
|
||||
// Package logging wires the nebula runtime-reconfigurable slog handler used
|
||||
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
|
||||
// NewLogger, then call ApplyConfig at startup and from a config reload
|
||||
// callback to push logging.level, logging.format, and
|
||||
// logging.disable_timestamp changes onto the logger without rebuilding it.
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
|
||||
// here keeps the logging package from depending on config directly, which
|
||||
// would cycle through the shared test helpers (test.NewLogger imports
|
||||
// logging, and config's tests import test). *config.C satisfies this
|
||||
// interface structurally with no adapter.
|
||||
type Config interface {
|
||||
GetString(key, def string) string
|
||||
GetBool(key string, def bool) bool
|
||||
}
|
||||
|
||||
// LevelTrace is a custom slog level below Debug, used when logging.level is
|
||||
// "trace". slog has no builtin trace level; the value is one step below
|
||||
// slog.LevelDebug in slog's 4-point spacing.
|
||||
const LevelTrace = slog.Level(-8)
|
||||
|
||||
// NewLogger returns a *slog.Logger whose level, format, and timestamp
|
||||
// emission can be reconfigured at runtime via ApplyConfig and the SSH debug
|
||||
// commands. The default configuration is info-level text output so log
|
||||
// calls made before ApplyConfig runs still produce output. Timestamps
|
||||
// follow slog's default RFC3339Nano format; set logging.disable_timestamp
|
||||
// in config to suppress them.
|
||||
//
|
||||
// ApplyConfig and the SSH commands discover the reconfig surface via
|
||||
// structural type-assertion on l.Handler(), so replacement implementations
|
||||
// (tests, platform-specific sinks) need only implement the subset of
|
||||
// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)}
|
||||
// they care about. Callers that pass a plain *slog.Logger without these
|
||||
// methods get a silent no-op; reconfiguration is always opt-in.
|
||||
func NewLogger(w io.Writer) *slog.Logger {
|
||||
return slog.New(NewHandler(w))
|
||||
}
|
||||
|
||||
// NewHandler builds the *Handler that NewLogger wraps. Exported for
|
||||
// platform-specific sinks (notably cmd/nebula-service/logs_windows.go)
|
||||
// that want to wrap the handler with extra behavior, such as tagging each
|
||||
// record with its Event Log severity, while still benefiting from all the
|
||||
// level / format / timestamp / WithAttrs machinery implemented here.
|
||||
func NewHandler(w io.Writer) *Handler {
|
||||
root := &handlerRoot{}
|
||||
root.level.Set(slog.LevelInfo)
|
||||
opts := &slog.HandlerOptions{Level: &root.level}
|
||||
return &Handler{
|
||||
root: root,
|
||||
text: slog.NewTextHandler(w, opts),
|
||||
json: slog.NewJSONHandler(w, opts),
|
||||
}
|
||||
}
|
||||
|
||||
// handlerRoot carries the reconfiguration state shared by every logger
|
||||
// derived from a NewHandler call. All fields are consulted on the log
|
||||
// path and updated lock-free.
|
||||
type handlerRoot struct {
|
||||
level slog.LevelVar
|
||||
disableTimestamp atomic.Bool
|
||||
// jsonMode picks which of the pre-derived inner handlers Handler.Handle
|
||||
// dispatches to. Flipping it propagates instantly to every derived logger
|
||||
// without rebuilding or chain-replaying anything.
|
||||
jsonMode atomic.Bool
|
||||
}
|
||||
|
||||
// Handler is the slog.Handler returned by NewHandler. It holds two
|
||||
// pre-derived slog handlers -- one text, one json -- both built from the
|
||||
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
|
||||
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
|
||||
// effect immediately across the whole process without having to rebuild
|
||||
// any derived loggers.
|
||||
type Handler struct {
|
||||
root *handlerRoot
|
||||
text slog.Handler
|
||||
json slog.Handler
|
||||
}
|
||||
|
||||
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
|
||||
return h.root.level.Level() <= l
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
|
||||
if h.root.disableTimestamp.Load() {
|
||||
r.Time = time.Time{}
|
||||
}
|
||||
if h.root.jsonMode.Load() {
|
||||
return h.json.Handle(ctx, r)
|
||||
}
|
||||
return h.text.Handle(ctx, r)
|
||||
}
|
||||
|
||||
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
if len(attrs) == 0 {
|
||||
return h
|
||||
}
|
||||
return &Handler{
|
||||
root: h.root,
|
||||
text: h.text.WithAttrs(attrs),
|
||||
json: h.json.WithAttrs(attrs),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) WithGroup(name string) slog.Handler {
|
||||
if name == "" {
|
||||
return h
|
||||
}
|
||||
return &Handler{
|
||||
root: h.root,
|
||||
text: h.text.WithGroup(name),
|
||||
json: h.json.WithGroup(name),
|
||||
}
|
||||
}
|
||||
|
||||
// SetLevel updates the effective log level. Propagates to every derived
|
||||
// logger via the shared LevelVar.
|
||||
func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) }
|
||||
|
||||
// GetLevel reports the current log level.
|
||||
func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() }
|
||||
|
||||
// SetFormat flips the output format atomically. Valid formats are "text"
|
||||
// and "json". Every derived logger sees the new format on its next Handle
|
||||
// call; no rebuild or registration is required.
|
||||
func (h *Handler) SetFormat(format string) error {
|
||||
switch format {
|
||||
case "text":
|
||||
h.root.jsonMode.Store(false)
|
||||
case "json":
|
||||
h.root.jsonMode.Store(true)
|
||||
default:
|
||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFormat reports the currently selected format name.
|
||||
func (h *Handler) GetFormat() string {
|
||||
if h.root.jsonMode.Load() {
|
||||
return "json"
|
||||
}
|
||||
return "text"
|
||||
}
|
||||
|
||||
// SetDisableTimestamp toggles whether Handle zeroes r.Time before
|
||||
// dispatching (slog's builtin text/json handlers skip emitting the time
|
||||
// attribute on a zero time).
|
||||
func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) }
|
||||
|
||||
// ApplyConfig reads logging.level, logging.format, and (optionally)
|
||||
// logging.disable_timestamp from c and applies them to l. The reconfig
|
||||
// surface is discovered via structural type-assertion on l.Handler(), so
|
||||
// foreign handlers silently opt out of whichever capabilities they do not
|
||||
// implement.
|
||||
//
|
||||
// nebula.Main does NOT call this function on your behalf; callers that want
|
||||
// config-driven log level / format / timestamp updates invoke it at
|
||||
// startup and register it as a reload callback themselves. This keeps the
|
||||
// library from mutating an embedder's logger without their say-so.
|
||||
func ApplyConfig(l *slog.Logger, c Config) error {
|
||||
h := l.Handler()
|
||||
|
||||
lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok {
|
||||
ls.SetLevel(lvl)
|
||||
}
|
||||
|
||||
format := strings.ToLower(c.GetString("logging.format", "text"))
|
||||
if fs, ok := h.(interface{ SetFormat(string) error }); ok {
|
||||
if err := fs.SetFormat(format); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok {
|
||||
ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseLevel converts a config-string level name ("trace", "debug", "info",
|
||||
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and
|
||||
// "panic" are accepted for backwards compatibility with pre-slog configs
|
||||
// and both map to slog.LevelError.
|
||||
func ParseLevel(s string) (slog.Level, error) {
|
||||
switch s {
|
||||
case "trace":
|
||||
return LevelTrace, nil
|
||||
case "debug":
|
||||
return slog.LevelDebug, nil
|
||||
case "info":
|
||||
return slog.LevelInfo, nil
|
||||
case "warn", "warning":
|
||||
return slog.LevelWarn, nil
|
||||
case "error":
|
||||
return slog.LevelError, nil
|
||||
case "fatal", "panic":
|
||||
return slog.LevelError, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("not a valid logging level: %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
// LevelName returns a human-readable name for a slog.Level matching the
|
||||
// strings accepted by ParseLevel.
|
||||
func LevelName(l slog.Level) string {
|
||||
switch {
|
||||
case l <= LevelTrace:
|
||||
return "trace"
|
||||
case l <= slog.LevelDebug:
|
||||
return "debug"
|
||||
case l <= slog.LevelInfo:
|
||||
return "info"
|
||||
case l <= slog.LevelWarn:
|
||||
return "warn"
|
||||
default:
|
||||
return "error"
|
||||
}
|
||||
}
|
||||
90
logging/logger_bench_test.go
Normal file
90
logging/logger_bench_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// BenchmarkLogger_* compare the handler returned by NewLogger against a
|
||||
// stock slog text handler. The key thing we care about is the per-log
|
||||
// cost on a logger that has been derived via .With(), because that is the
|
||||
// shape subsystems store on their structs (HostInfo.logger(),
|
||||
// lh.l.With("subsystem", ...), etc.) and call from hot paths.
|
||||
|
||||
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
|
||||
l := slog.New(slog.DiscardHandler)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Info("hello", "i", i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
|
||||
l := NewLogger(io.Discard)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Info("hello", "i", i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
|
||||
l := slog.New(slog.DiscardHandler).With(
|
||||
"subsystem", "bench",
|
||||
"localIndex", 1234,
|
||||
)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Info("hello", "i", i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
|
||||
l := NewLogger(io.Discard).With(
|
||||
"subsystem", "bench",
|
||||
"localIndex", 1234,
|
||||
)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Info("hello", "i", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Gated-off-path benchmarks: mimic the typical hot-path shape
|
||||
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below
|
||||
// the active level. This is the dominant pattern in inside.go/outside.go and
|
||||
// what we pay on every packet.
|
||||
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
|
||||
l := slog.New(slog.DiscardHandler).With(
|
||||
"subsystem", "bench",
|
||||
"localIndex", 1234,
|
||||
)
|
||||
ctx := context.Background()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if l.Enabled(ctx, slog.LevelDebug) {
|
||||
l.Debug("hello", "i", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
|
||||
l := NewLogger(io.Discard).With(
|
||||
"subsystem", "bench",
|
||||
"localIndex", 1234,
|
||||
)
|
||||
ctx := context.Background()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if l.Enabled(ctx, slog.LevelDebug) {
|
||||
l.Debug("hello", "i", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
39
main.go
39
main.go
@@ -3,13 +3,13 @@ package nebula
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
|
||||
type m = map[string]any
|
||||
|
||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||
func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
||||
defer func() {
|
||||
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
buildVersion = moduleVersion()
|
||||
}
|
||||
|
||||
l := logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
}
|
||||
|
||||
// Print the config if in test, the exit comes later
|
||||
if configTest {
|
||||
b, err := yaml.Marshal(c.Settings)
|
||||
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
|
||||
// Print the final config
|
||||
l.Println(string(b))
|
||||
l.Info(string(b))
|
||||
}
|
||||
|
||||
err := configLogger(l, c)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := configLogger(l, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to configure the logger")
|
||||
}
|
||||
})
|
||||
|
||||
pki, err := NewPKIFromConfig(l, c)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||
}
|
||||
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
||||
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||
}
|
||||
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
if c.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
||||
l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
|
||||
sshStart = nil
|
||||
}
|
||||
}
|
||||
@@ -99,7 +82,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
routines = 1
|
||||
}
|
||||
if routines > 1 {
|
||||
l.WithField("routines", routines).Info("Using multiple routines")
|
||||
l.Info("Using multiple routines", "routines", routines)
|
||||
}
|
||||
} else {
|
||||
// deprecated and undocumented
|
||||
@@ -107,7 +90,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
udpQueues := c.GetInt("listen.routines", 1)
|
||||
routines = max(tunQueues, udpQueues)
|
||||
if routines != 1 {
|
||||
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
|
||||
l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
conntrackCacheTimeout = 1 * time.Second
|
||||
}
|
||||
if conntrackCacheTimeout > 0 {
|
||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
||||
l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
|
||||
}
|
||||
|
||||
var tun overlay.Device
|
||||
@@ -166,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
|
||||
for i := 0; i < routines; i++ {
|
||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||
l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
@@ -217,7 +200,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
|
||||
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to start DNS responder")
|
||||
l.Warn("Failed to start DNS responder", "error", err)
|
||||
}
|
||||
|
||||
ifConfig := &InterfaceConfig{
|
||||
|
||||
149
outside.go
149
outside.go
@@ -1,15 +1,16 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"golang.org/x/net/ipv4"
|
||||
@@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
if err != nil {
|
||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||
if len(packet) > 1 {
|
||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
|
||||
f.l.Info("Error while parsing inbound packet",
|
||||
"from", via,
|
||||
"error", err,
|
||||
"packet", packet,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||
if !via.IsRelayed {
|
||||
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Refusing to process double encrypted packet", "from", via)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
if !ok {
|
||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||
// its internal mapping. This should never happen.
|
||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
||||
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
// Find the target HostInfo relay object
|
||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
||||
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
|
||||
"relayTo", relay.PeerAddr,
|
||||
"error", err,
|
||||
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||
}
|
||||
} else {
|
||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
||||
hostinfo.logger(f.l).Info("Unexpected target relay state",
|
||||
"relayTo", relay.PeerAddr,
|
||||
"relayFrom", hostinfo.vpnAddrs[0],
|
||||
"targetRelayState", targetRelay.State,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
|
||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt lighthouse packet")
|
||||
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"packet", packet,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
|
||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt test packet")
|
||||
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"packet", packet,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -192,14 +212,15 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
}
|
||||
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt CloseTunnel packet")
|
||||
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"packet", packet,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.logger(f.l).WithField("from", via).
|
||||
Info("Close tunnel received, tearing down.")
|
||||
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
|
||||
|
||||
f.closeTunnel(hostinfo)
|
||||
return
|
||||
@@ -211,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
|
||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt Control packet")
|
||||
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"packet", packet,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -221,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
||||
|
||||
default:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -247,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
|
||||
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
|
||||
"suppressSeconds", RoamingSuppressSeconds,
|
||||
"udpAddr", hostinfo.remote,
|
||||
"newAddr", via.UdpAddr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
||||
Info("Host roamed to new udp ip/port.")
|
||||
hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
|
||||
"udpAddr", hostinfo.remote,
|
||||
"newAddr", via.UdpAddr,
|
||||
)
|
||||
hostinfo.lastRoam = time.Now()
|
||||
hostinfo.lastRoamRemote = hostinfo.remote
|
||||
hostinfo.SetRemote(via.UdpAddr)
|
||||
@@ -491,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
||||
hostinfo.logger(f.l).WithField("header", h).
|
||||
Debugln("dropping out of window packet")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
|
||||
}
|
||||
return nil, errors.New("out of window packet")
|
||||
}
|
||||
|
||||
@@ -504,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = newPacket(out, true, fwPacket)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||
Warnf("Error while validating inbound packet")
|
||||
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
||||
"error", err,
|
||||
"packet", out,
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping out of window packet")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -526,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping inbound packet")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("dropping inbound packet",
|
||||
"fwPacket", fwPacket,
|
||||
"reason", dropReason,
|
||||
)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -537,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
f.connectionManager.In(hostinfo)
|
||||
_, err = f.readers[q].Write(out)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to write to tun")
|
||||
f.l.Error("Failed to write to tun", "error", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -553,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
||||
|
||||
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||
_ = f.outside.WriteTo(b, endpoint)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("index", index).
|
||||
WithField("udpAddr", endpoint).
|
||||
Debug("Recv error sent")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Recv error sent",
|
||||
"index", index,
|
||||
"udpAddr", endpoint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
|
||||
f.l.WithField("index", h.RemoteIndex).
|
||||
WithField("udpAddr", addr).
|
||||
Debug("Recv error received, ignoring")
|
||||
f.l.Debug("Recv error received, ignoring",
|
||||
"index", h.RemoteIndex,
|
||||
"udpAddr", addr,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("index", h.RemoteIndex).
|
||||
WithField("udpAddr", addr).
|
||||
Debug("Recv error received")
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Recv error received",
|
||||
"index", h.RemoteIndex,
|
||||
"udpAddr", addr,
|
||||
)
|
||||
}
|
||||
|
||||
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
||||
if hostinfo == nil {
|
||||
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
|
||||
f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
|
||||
return
|
||||
}
|
||||
|
||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||
f.l.Info("Someone spoofing recv_errors?",
|
||||
"addr", addr,
|
||||
"hostinfoRemote", hostinfo.remote,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
package test
|
||||
// Package overlaytest provides fakes of overlay.Device for tests that do
|
||||
// not want to touch a real tun device or route table.
|
||||
package overlaytest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -8,6 +10,9 @@ import (
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
// NoopTun is an overlay.Device that silently discards every read and write.
|
||||
// Useful in tests that need to construct a nebula Interface but do not
|
||||
// exercise the datapath.
|
||||
type NoopTun struct{}
|
||||
|
||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||
@@ -2,6 +2,7 @@ package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -9,7 +10,6 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
@@ -48,11 +48,14 @@ func (r Route) String() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
||||
func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
||||
routeTree := new(bart.Table[routing.Gateways])
|
||||
for _, r := range routes {
|
||||
if !allowMTU && r.MTU > 0 {
|
||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||
l.Warn("route MTU is not supported on this platform",
|
||||
"goos", runtime.GOOS,
|
||||
"route", r,
|
||||
)
|
||||
}
|
||||
|
||||
gateways := r.Via
|
||||
|
||||
@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
|
||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 2)
|
||||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip, err := netip.ParseAddr("1.0.0.2")
|
||||
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
|
||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 3)
|
||||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip, err := netip.ParseAddr("192.168.86.1")
|
||||
|
||||
@@ -2,10 +2,10 @@ package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
@@ -22,9 +22,9 @@ func (e *NameError) Error() string {
|
||||
}
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
switch {
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
@@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
|
||||
}
|
||||
|
||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,12 @@ package overlay
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -23,10 +23,10 @@ type tun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -30,7 +30,7 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
|
||||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
@@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
err := addRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
t.l.WithField("route", r.Cidr).
|
||||
Warnf("unable to add unsafe_route, identical route already exists")
|
||||
t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr)
|
||||
} else {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
@@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
t.l.Info("Added route", "route", r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
t.l.Info("Removed route", "route", r)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
@@ -19,10 +20,10 @@ type disabledTun struct {
|
||||
// Track these metrics since we don't have the tun device to do it for us
|
||||
tx metrics.Counter
|
||||
rx metrics.Counter
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
|
||||
tun := &disabledTun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
read: make(chan []byte, queueLen),
|
||||
@@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
|
||||
}
|
||||
|
||||
t.tx.Inc(1)
|
||||
if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
||||
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
t.l.Debug("Write payload", "raw", prettyPacket(r))
|
||||
}
|
||||
|
||||
return copy(b, r), nil
|
||||
@@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
||||
select {
|
||||
case t.read <- out:
|
||||
default:
|
||||
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
||||
t.l.Debug("tun_disabled: dropped ICMP Echo Reply response")
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
||||
|
||||
// Check for ICMP Echo Request before spending time doing the full parsing
|
||||
if t.handleICMPEchoRequest(b) {
|
||||
if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
||||
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b))
|
||||
}
|
||||
} else if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
||||
} else if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b))
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
@@ -17,8 +18,9 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
@@ -93,7 +95,7 @@ type tun struct {
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
|
||||
fd int
|
||||
shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls
|
||||
@@ -243,7 +245,7 @@ func (t *tun) Close() error {
|
||||
|
||||
if t.fd >= 0 {
|
||||
if err := unix.Close(t.fd); err != nil {
|
||||
t.l.WithError(err).Error("Error closing device")
|
||||
t.l.Error("Error closing device", "error", err)
|
||||
}
|
||||
t.fd = -1
|
||||
}
|
||||
@@ -264,7 +266,7 @@ func (t *tun) Close() error {
|
||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||
}
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Error destroying tunnel")
|
||||
t.l.Error("Error destroying tunnel", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -277,11 +279,11 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var fd int
|
||||
var err error
|
||||
@@ -584,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
t.l.Info("Added route", "route", r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -599,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
|
||||
err := delRoute(r.Cidr, t.linkAddr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
t.l.Info("Removed route", "route", r)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -25,14 +25,14 @@ type tun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -17,7 +18,6 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -213,7 +213,7 @@ type tun struct {
|
||||
routesFromSystem map[netip.Prefix]routing.Gateways
|
||||
routesFromSystemLock sync.Mutex
|
||||
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
@@ -238,7 +238,7 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -249,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
@@ -299,7 +299,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
|
||||
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
tfd, err := newTunFd(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
@@ -378,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
if !initial {
|
||||
if oldMaxMTU != newMaxMTU {
|
||||
t.setMTU()
|
||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
||||
t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU)
|
||||
}
|
||||
|
||||
if oldDefaultMTU != newDefaultMTU {
|
||||
for i := range t.vpnNetworks {
|
||||
err := t.setDefaultRoute(t.vpnNetworks[i])
|
||||
if err != nil {
|
||||
t.l.Warn(err)
|
||||
t.l.Warn(err.Error())
|
||||
} else {
|
||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||
t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -492,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error {
|
||||
}
|
||||
err = netlink.AddrDel(link, &al[i])
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
||||
t.l.Error("failed to remove address from tun address list", "error", err)
|
||||
} else {
|
||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
||||
t.l.Info("removed address not listed in cert(s)", "removed", al[i].String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -538,12 +538,12 @@ func (t *tun) Activate() error {
|
||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
t.l.Error("Failed to set tun tx queue length", "error", err)
|
||||
}
|
||||
|
||||
const modeNone = 1
|
||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||
t.l.Warn("Failed to disable link local address generation", "error", err)
|
||||
}
|
||||
|
||||
if err = t.addIPs(link); err != nil {
|
||||
@@ -582,7 +582,7 @@ func (t *tun) setMTU() {
|
||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
||||
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||
t.l.Error("Failed to set tun mtu", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -605,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||
}
|
||||
err := netlink.RouteReplace(&nr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
||||
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
|
||||
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
||||
for i := 0; i < 2; i++ {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -613,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||
if err == nil {
|
||||
break
|
||||
} else {
|
||||
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
|
||||
t.l.Warn("Failed to set default route MTU, retrying",
|
||||
"error", err,
|
||||
"cidr", cidr,
|
||||
"mtu", t.DefaultMTU,
|
||||
)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -658,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
t.l.Info("Added route", "route", r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -690,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) {
|
||||
|
||||
err := netlink.RouteDel(&nr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
t.l.Info("Removed route", "route", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -721,11 +725,11 @@ func (t *tun) watchRoutes() {
|
||||
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
||||
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
||||
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
||||
ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) },
|
||||
}
|
||||
|
||||
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
|
||||
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
|
||||
t.l.Error("failed to subscribe to system route changes", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -767,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||
|
||||
link, err := netlink.LinkByName(t.Device)
|
||||
if err != nil {
|
||||
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
|
||||
t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device)
|
||||
return gateways
|
||||
}
|
||||
|
||||
@@ -779,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||
} else {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
||||
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
||||
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -795,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||
} else {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
||||
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
||||
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -830,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||
if len(gateways) == 0 {
|
||||
// No gateways relevant to our network, no routing changes required.
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||
t.l.Debug("Ignoring route update, no gateways", "route", r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Dst == nil {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
||||
t.l.Debug("Ignoring route update, no destination address", "route", r)
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
||||
t.l.Debug("Ignoring route update, invalid destination address", "route", r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -852,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
|
||||
t.routesFromSystemLock.Lock()
|
||||
if r.Type == unix.RTM_NEWROUTE {
|
||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
||||
t.l.Info("Adding route", "destination", dst, "via", gateways)
|
||||
t.routesFromSystem[dst] = gateways
|
||||
newTree.Insert(dst, gateways)
|
||||
|
||||
} else {
|
||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
||||
t.l.Info("Removing route", "destination", dst, "via", gateways)
|
||||
delete(t.routesFromSystem, dst)
|
||||
newTree.Delete(dst)
|
||||
}
|
||||
@@ -888,18 +892,18 @@ func (t *tun) Close() error {
|
||||
}
|
||||
err := t.readers[i].Close()
|
||||
if err != nil {
|
||||
t.l.WithField("reader", i).WithError(err).Error("error closing tun reader")
|
||||
t.l.Error("error closing tun reader", "reader", i, "error", err)
|
||||
} else {
|
||||
t.l.WithField("reader", i).Info("closed tun reader")
|
||||
t.l.Info("closed tun reader", "reader", i)
|
||||
}
|
||||
}
|
||||
|
||||
//this is t.readers[0] too
|
||||
err := t.tunFile.Close()
|
||||
if err != nil {
|
||||
t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader")
|
||||
t.l.Error("error closing tun reader", "reader", 0, "error", err)
|
||||
} else {
|
||||
t.l.WithField("reader", 0).Info("closed tun reader")
|
||||
t.l.Info("closed tun reader", "reader", 0)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -63,18 +63,18 @@ type tun struct {
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
@@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
l.Warn("Failed to set the tun device as nonblocking", "error", err)
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
@@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
t.l.Info("Added route", "route", r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
t.l.Info("Removed route", "route", r)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -54,7 +54,7 @@ type tun struct {
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
@@ -63,11 +63,11 @@ type tun struct {
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
@@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||
l.Warn("Failed to set the tun device as nonblocking", "error", err)
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
@@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
t.l.Info("Added route", "route", r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
||||
|
||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
t.l.Info("Removed route", "route", r)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -4,14 +4,15 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
@@ -21,14 +22,14 @@ type TestTun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes []Route
|
||||
routeTree *bart.Table[routing.Gateways]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
|
||||
closed atomic.Bool
|
||||
rxPackets chan []byte // Packets to receive into nebula
|
||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||
}
|
||||
|
||||
@@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
|
||||
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
|
||||
}
|
||||
t.rxPackets <- packet
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"crypto"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -33,16 +33,16 @@ type winTun struct {
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
err := checkWinTunExists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||
@@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
if err != nil {
|
||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||
// Trying a second time resolves the issue.
|
||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
||||
l.Debug("Failed to create wintun device, retrying", "error", err)
|
||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||
if err != nil {
|
||||
return nil, &NameError{
|
||||
@@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
t.l.Info("Added route", "route", r)
|
||||
}
|
||||
|
||||
if !foundDefault4 {
|
||||
@@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
||||
// See comment on luid.AddRoute
|
||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
t.l.Info("Removed route", "route", r)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -2,14 +2,14 @@ package overlay
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return NewUserDevice(vpnNetworks)
|
||||
}
|
||||
|
||||
|
||||
18
pki.go
18
pki.go
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
type PKI struct {
|
||||
cs atomic.Pointer[CertState]
|
||||
caPool atomic.Pointer[cert.CAPool]
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
type CertState struct {
|
||||
@@ -46,7 +46,7 @@ type CertState struct {
|
||||
myVpnBroadcastAddrsTable *bart.Lite
|
||||
}
|
||||
|
||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||
func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) {
|
||||
pki := &PKI{l: l}
|
||||
err := pki.reload(c, true)
|
||||
if err != nil {
|
||||
@@ -182,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
p.cs.Store(newState)
|
||||
|
||||
if initial {
|
||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||
p.l.Debug("Client nebula certificate(s)", "cert", newState)
|
||||
} else {
|
||||
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
|
||||
p.l.Info("Client certificate(s) refreshed from disk", "cert", newState)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -196,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
||||
}
|
||||
|
||||
p.caPool.Store(caPool)
|
||||
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -487,7 +487,7 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
|
||||
return c, b, nil
|
||||
}
|
||||
|
||||
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||
func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) {
|
||||
caPathOrPEM := c.GetString("pki.ca", "")
|
||||
if caPathOrPEM == "" {
|
||||
return nil, errors.New("no pki.ca path or PEM data provided")
|
||||
@@ -512,7 +512,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||
for _, crt := range caPool.CAs {
|
||||
if crt.Certificate.Expired(time.Now()) {
|
||||
expired++
|
||||
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
|
||||
l.Warn("expired certificate present in CA pool", "cert", crt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||
caPool.BlocklistFingerprint(fp)
|
||||
}
|
||||
|
||||
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
||||
l.Info("Blocklisted certificates", "fingerprintCount", len(bl))
|
||||
}
|
||||
|
||||
return caPool, nil
|
||||
|
||||
@@ -41,7 +41,7 @@ func BenchmarkReloadConfigWithCAs(b *testing.B) {
|
||||
c := config.NewC(l)
|
||||
require.NoError(b, c.Load(dir))
|
||||
|
||||
_, err := NewPKIFromConfig(l, c)
|
||||
_, err := NewPKIFromConfig(test.NewLogger(), c)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ReportAllocs()
|
||||
|
||||
14
punchy.go
14
punchy.go
@@ -1,10 +1,10 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
@@ -14,10 +14,10 @@ type Punchy struct {
|
||||
delay atomic.Int64
|
||||
respondDelay atomic.Int64
|
||||
punchEverything atomic.Bool
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
|
||||
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
|
||||
p := &Punchy{l: l}
|
||||
|
||||
p.reload(c, true)
|
||||
@@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
||||
p.respond.Store(yes)
|
||||
|
||||
if !initial {
|
||||
p.l.Infof("punchy.respond changed to %v", p.GetRespond())
|
||||
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
||||
if initial || c.HasChanged("punchy.delay") {
|
||||
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
|
||||
if !initial {
|
||||
p.l.Infof("punchy.delay changed to %s", p.GetDelay())
|
||||
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
|
||||
}
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("punchy.target_all_remotes") {
|
||||
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
|
||||
if !initial {
|
||||
p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
|
||||
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
|
||||
}
|
||||
}
|
||||
|
||||
if initial || c.HasChanged("punchy.respond_delay") {
|
||||
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
|
||||
if !initial {
|
||||
p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay())
|
||||
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
173
punchy_test.go
173
punchy_test.go
@@ -1,6 +1,8 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
||||
c := config.NewC(l)
|
||||
|
||||
// Test defaults
|
||||
p := NewPunchyFromConfig(l, c)
|
||||
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.False(t, p.GetPunch())
|
||||
assert.False(t, p.GetRespond())
|
||||
assert.Equal(t, time.Second, p.GetDelay())
|
||||
@@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
||||
|
||||
// punchy deprecation
|
||||
c.Settings["punchy"] = true
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetPunch())
|
||||
|
||||
// punchy.punch
|
||||
c.Settings["punchy"] = map[string]any{"punch": true}
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetPunch())
|
||||
|
||||
// punch_back deprecation
|
||||
c.Settings["punch_back"] = true
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetRespond())
|
||||
|
||||
// punchy.respond
|
||||
c.Settings["punchy"] = map[string]any{"respond": true}
|
||||
c.Settings["punch_back"] = false
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.True(t, p.GetRespond())
|
||||
|
||||
// punchy.delay
|
||||
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.Equal(t, time.Minute, p.GetDelay())
|
||||
|
||||
// punchy.respond_delay
|
||||
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
||||
}
|
||||
|
||||
@@ -62,7 +64,7 @@ punchy:
|
||||
delay: 1m
|
||||
respond: false
|
||||
`))
|
||||
p := NewPunchyFromConfig(l, c)
|
||||
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||
assert.Equal(t, delay, p.GetDelay())
|
||||
assert.False(t, p.GetRespond())
|
||||
|
||||
@@ -76,3 +78,158 @@ punchy:
|
||||
assert.Equal(t, newDelay, p.GetDelay())
|
||||
assert.True(t, p.GetRespond())
|
||||
}
|
||||
|
||||
// The tests below pin the shape of each log line Punchy produces so changes
|
||||
// cannot silently break whatever operators are grepping for. The assertions
|
||||
// are on the structured message + attrs (e.g. "punchy.respond changed" with
|
||||
// a respond=true field) rather than a formatted string.
|
||||
//
|
||||
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
|
||||
// not supported" warning whenever any key under punchy changes, because of
|
||||
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
|
||||
// punchy form. The tests filter by message rather than asserting total
|
||||
// entry counts so that warning is tolerated without being locked into
|
||||
// the format.
|
||||
|
||||
type capturedEntry struct {
|
||||
Level slog.Level
|
||||
Msg string
|
||||
Attrs map[string]any
|
||||
}
|
||||
|
||||
// capturingHandler is a slog.Handler that records each Record it receives so
|
||||
// tests can assert on the level, message, and attribute map of individual log
|
||||
// lines without coupling to any specific text format.
|
||||
type capturingHandler struct {
|
||||
entries []capturedEntry
|
||||
}
|
||||
|
||||
func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
|
||||
|
||||
func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error {
|
||||
e := capturedEntry{
|
||||
Level: r.Level,
|
||||
Msg: r.Message,
|
||||
Attrs: make(map[string]any),
|
||||
}
|
||||
r.Attrs(func(a slog.Attr) bool {
|
||||
e.Attrs[a.Key] = a.Value.Resolve().Any()
|
||||
return true
|
||||
})
|
||||
h.entries = append(h.entries, e)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
|
||||
func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h }
|
||||
|
||||
func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) {
|
||||
t.Helper()
|
||||
hook := &capturingHandler{}
|
||||
return slog.New(hook), hook
|
||||
}
|
||||
|
||||
func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry {
|
||||
t.Helper()
|
||||
for _, e := range entries {
|
||||
if e.Msg == msg {
|
||||
return e
|
||||
}
|
||||
}
|
||||
t.Fatalf("no entry with message %q among %d entries", msg, len(entries))
|
||||
return capturedEntry{}
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
|
||||
|
||||
NewPunchyFromConfig(l, c)
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy enabled")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Empty(t, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||
|
||||
NewPunchyFromConfig(l, c)
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy disabled")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Empty(t, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
|
||||
|
||||
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
|
||||
assert.Equal(t, slog.LevelWarn, entry.Level)
|
||||
assert.Empty(t, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy.respond changed")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Equal(t, map[string]any{"respond": true}, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy.delay changed")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Equal(t, map[string]any{"delay": 10 * time.Second}, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy.target_all_remotes changed")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Equal(t, map[string]any{"target_all_remotes": true}, entry.Attrs)
|
||||
}
|
||||
|
||||
func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
|
||||
l, hook := newCapturingPunchyLogger(t)
|
||||
c := config.NewC(test.NewLogger())
|
||||
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
|
||||
NewPunchyFromConfig(l, c)
|
||||
hook.entries = nil
|
||||
|
||||
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
|
||||
|
||||
entry := findEntry(t, hook.entries, "punchy.respond_delay changed")
|
||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||
assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, entry.Attrs)
|
||||
}
|
||||
|
||||
165
relay_manager.go
165
relay_manager.go
@@ -5,22 +5,22 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
type relayManager struct {
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
hostmap *HostMap
|
||||
amRelay atomic.Bool
|
||||
}
|
||||
|
||||
func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager {
|
||||
func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
|
||||
rm := &relayManager{
|
||||
l: l,
|
||||
hostmap: hostmap,
|
||||
@@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := rm.reload(c, false)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to reload relay_manager")
|
||||
rm.l.Error("Failed to reload relay_manager", "error", err)
|
||||
}
|
||||
})
|
||||
return rm
|
||||
@@ -52,7 +52,7 @@ func (rm *relayManager) setAmRelay(v bool) {
|
||||
|
||||
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
|
||||
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
|
||||
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
||||
func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
||||
hm.Lock()
|
||||
defer hm.Unlock()
|
||||
for range 32 {
|
||||
@@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
|
||||
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
|
||||
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
|
||||
if !ok {
|
||||
fields := logrus.Fields{
|
||||
"relay": relayHostInfo.vpnAddrs[0],
|
||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
||||
}
|
||||
|
||||
var relayFrom, relayTo any
|
||||
if m.RelayFromAddr == nil {
|
||||
fields["relayFrom"] = m.OldRelayFromAddr
|
||||
relayFrom = m.OldRelayFromAddr
|
||||
} else {
|
||||
fields["relayFrom"] = m.RelayFromAddr
|
||||
relayFrom = m.RelayFromAddr
|
||||
}
|
||||
|
||||
if m.RelayToAddr == nil {
|
||||
fields["relayTo"] = m.OldRelayToAddr
|
||||
relayTo = m.OldRelayToAddr
|
||||
} else {
|
||||
fields["relayTo"] = m.RelayToAddr
|
||||
relayTo = m.RelayToAddr
|
||||
}
|
||||
|
||||
rm.l.WithFields(fields).Info("relayManager failed to update relay")
|
||||
rm.l.Info("relayManager failed to update relay",
|
||||
"relay", relayHostInfo.vpnAddrs[0],
|
||||
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||
"relayFrom", relayFrom,
|
||||
"relayTo", relayTo,
|
||||
)
|
||||
return nil, fmt.Errorf("unknown relay")
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
|
||||
msg := &NebulaControl{}
|
||||
err := msg.Unmarshal(d)
|
||||
if err != nil {
|
||||
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
|
||||
h.logger(f.l).Error("Failed to unmarshal control message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
|
||||
}
|
||||
|
||||
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
|
||||
"relayTo": protoAddrToNetAddr(m.RelayToAddr),
|
||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
||||
"responderRelayIndex": m.ResponderRelayIndex,
|
||||
"vpnAddrs": h.vpnAddrs}).
|
||||
Info("handleCreateRelayResponse")
|
||||
rm.l.Info("handleCreateRelayResponse",
|
||||
"relayFrom", protoAddrToNetAddr(m.RelayFromAddr),
|
||||
"relayTo", protoAddrToNetAddr(m.RelayToAddr),
|
||||
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||
"responderRelayIndex", m.ResponderRelayIndex,
|
||||
"vpnAddrs", h.vpnAddrs,
|
||||
)
|
||||
|
||||
target := m.RelayToAddr
|
||||
targetAddr := protoAddrToNetAddr(target)
|
||||
|
||||
relay, err := rm.EstablishRelay(h, m)
|
||||
if err != nil {
|
||||
rm.l.WithError(err).Error("Failed to update relay for relayTo")
|
||||
rm.l.Error("Failed to update relay for relayTo", "error", err)
|
||||
return
|
||||
}
|
||||
// Do I need to complete the relays now?
|
||||
@@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
||||
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
|
||||
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
|
||||
if peerHostInfo == nil {
|
||||
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
|
||||
rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr)
|
||||
return
|
||||
}
|
||||
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
||||
if !ok {
|
||||
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
|
||||
rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0])
|
||||
return
|
||||
}
|
||||
switch peerRelay.State {
|
||||
@@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
||||
if v == cert.Version1 {
|
||||
peer := peerHostInfo.vpnAddrs[0]
|
||||
if !peer.Is4() {
|
||||
rm.l.WithField("relayFrom", peer).
|
||||
WithField("relayTo", target).
|
||||
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
|
||||
WithField("responderRelayIndex", resp.ResponderRelayIndex).
|
||||
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
|
||||
Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address")
|
||||
rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address",
|
||||
"relayFrom", peer,
|
||||
"relayTo", target,
|
||||
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||
"vpnAddrs", peerHostInfo.vpnAddrs,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
||||
|
||||
msg, err := resp.Marshal()
|
||||
if err != nil {
|
||||
rm.l.WithError(err).
|
||||
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
||||
rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
|
||||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": resp.RelayFromAddr,
|
||||
"relayTo": resp.RelayToAddr,
|
||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||
"vpnAddrs": peerHostInfo.vpnAddrs}).
|
||||
Info("send CreateRelayResponse")
|
||||
rm.l.Info("send CreateRelayResponse",
|
||||
"relayFrom", resp.RelayFromAddr,
|
||||
"relayTo", resp.RelayToAddr,
|
||||
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||
"vpnAddrs", peerHostInfo.vpnAddrs,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
from := protoAddrToNetAddr(m.RelayFromAddr)
|
||||
target := protoAddrToNetAddr(m.RelayToAddr)
|
||||
|
||||
logMsg := rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": from,
|
||||
"relayTo": target,
|
||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
||||
"vpnAddrs": h.vpnAddrs})
|
||||
logMsg := rm.l.With(
|
||||
"relayFrom", from,
|
||||
"relayTo", target,
|
||||
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||
"vpnAddrs", h.vpnAddrs,
|
||||
)
|
||||
|
||||
logMsg.Info("handleCreateRelayRequest")
|
||||
// Is the source of the relay me? This should never happen, but did happen due to
|
||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||
if f.myVpnAddrsTable.Contains(from) {
|
||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
||||
logMsg.Error("Discarding relay request from myself", "myIP", from)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
||||
// We got a brand new Relay request, because its index is different than what we saw before.
|
||||
// This should never happen. The peer should never change an index, once created.
|
||||
logMsg.WithFields(logrus.Fields{
|
||||
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
||||
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
|
||||
"existingRemoteIndex", existingRelay.RemoteIndex)
|
||||
return
|
||||
}
|
||||
case Disestablished:
|
||||
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
||||
// We got a brand new Relay request, because its index is different than what we saw before.
|
||||
// This should never happen. The peer should never change an index, once created.
|
||||
logMsg.WithFields(logrus.Fields{
|
||||
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
||||
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
|
||||
"existingRemoteIndex", existingRelay.RemoteIndex)
|
||||
return
|
||||
}
|
||||
// Mark the relay as 'Established' because it's safe to use again
|
||||
h.relayState.UpdateRelayForByIpState(from, Established)
|
||||
case PeerRequested:
|
||||
// I should never be in this state, because I am terminal, not forwarding.
|
||||
logMsg.WithFields(logrus.Fields{
|
||||
"existingRemoteIndex": existingRelay.RemoteIndex,
|
||||
"state": existingRelay.State}).Error("Unexpected Relay State found")
|
||||
logMsg.Error("Unexpected Relay State found",
|
||||
"existingRemoteIndex", existingRelay.RemoteIndex,
|
||||
"state", existingRelay.State)
|
||||
}
|
||||
} else {
|
||||
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
|
||||
if err != nil {
|
||||
logMsg.WithError(err).Error("Failed to add relay")
|
||||
logMsg.Error("Failed to add relay", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
relay, ok := h.relayState.QueryRelayForByIp(from)
|
||||
if !ok {
|
||||
logMsg.WithField("from", from).Error("Relay State not found")
|
||||
logMsg.Error("Relay State not found", "from", from)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
|
||||
msg, err := resp.Marshal()
|
||||
if err != nil {
|
||||
logMsg.
|
||||
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
||||
logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
|
||||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": from,
|
||||
"relayTo": target,
|
||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||
"vpnAddrs": h.vpnAddrs}).
|
||||
Info("send CreateRelayResponse")
|
||||
rm.l.Info("send CreateRelayResponse",
|
||||
"relayFrom", from,
|
||||
"relayTo", target,
|
||||
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||
"vpnAddrs", h.vpnAddrs,
|
||||
)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
@@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
|
||||
if v == cert.Version1 {
|
||||
if !h.vpnAddrs[0].Is4() {
|
||||
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
|
||||
WithField("relayTo", target).
|
||||
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
|
||||
WithField("responderRelayIndex", req.ResponderRelayIndex).
|
||||
WithField("vpnAddr", target).
|
||||
Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address")
|
||||
rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address",
|
||||
"relayFrom", h.vpnAddrs[0],
|
||||
"relayTo", target,
|
||||
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||
"responderRelayIndex", req.ResponderRelayIndex,
|
||||
"vpnAddr", target,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
|
||||
msg, err := req.Marshal()
|
||||
if err != nil {
|
||||
logMsg.
|
||||
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
|
||||
logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err)
|
||||
} else {
|
||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||
rm.l.WithFields(logrus.Fields{
|
||||
"relayFrom": h.vpnAddrs[0],
|
||||
"relayTo": target,
|
||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||
"responderRelayIndex": req.ResponderRelayIndex,
|
||||
"vpnAddr": target}).
|
||||
Info("send CreateRelayRequest")
|
||||
rm.l.Info("send CreateRelayRequest",
|
||||
"relayFrom", h.vpnAddrs[0],
|
||||
"relayTo", target,
|
||||
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||
"responderRelayIndex", req.ResponderRelayIndex,
|
||||
"vpnAddr", target,
|
||||
)
|
||||
}
|
||||
|
||||
// Also track the half-created Relay state just received
|
||||
@@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
if !ok {
|
||||
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
||||
if err != nil {
|
||||
logMsg.
|
||||
WithError(err).Error("relayManager Failed to allocate a local index for relay")
|
||||
logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -10,8 +11,6 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// forEachFunc is used to benefit folks that want to do work inside the lock
|
||||
@@ -66,11 +65,11 @@ type hostnamesResults struct {
|
||||
network string
|
||||
lookupTimeout time.Duration
|
||||
cancelFn func()
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
ips atomic.Pointer[map[netip.AddrPort]struct{}]
|
||||
}
|
||||
|
||||
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
|
||||
func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
|
||||
r := &hostnamesResults{
|
||||
hostnames: make([]hostnamePort, len(hostPorts)),
|
||||
network: network,
|
||||
@@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
|
||||
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
|
||||
timeoutCancel()
|
||||
if err != nil {
|
||||
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
|
||||
l.Error("DNS resolution failed for static_map host",
|
||||
"hostname", hostPort.name,
|
||||
"network", r.network,
|
||||
"error", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
for _, a := range addrs {
|
||||
@@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
|
||||
}
|
||||
}
|
||||
if different {
|
||||
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
|
||||
l.Info("DNS results changed for host list",
|
||||
"origSet", origSet,
|
||||
"newSet", netipAddrs,
|
||||
)
|
||||
r.ips.Store(&netipAddrs)
|
||||
onUpdate()
|
||||
}
|
||||
|
||||
@@ -10,11 +10,11 @@ import (
|
||||
"time"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"go.yaml.in/yaml/v3"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
|
||||
panic(err)
|
||||
}
|
||||
|
||||
logger := logrus.New()
|
||||
logger.Out = os.Stdout
|
||||
logger := logging.NewLogger(os.Stdout)
|
||||
|
||||
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||
if err != nil {
|
||||
|
||||
85
ssh.go
85
ssh.go
@@ -6,21 +6,21 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
)
|
||||
|
||||
@@ -57,12 +57,12 @@ type sshDeviceInfoFlags struct {
|
||||
Pretty bool
|
||||
}
|
||||
|
||||
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
|
||||
func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) {
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
if c.GetBool("sshd.enabled", false) {
|
||||
sshRun, err := configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to reconfigure the sshd")
|
||||
l.Error("Failed to reconfigure the sshd", "error", err)
|
||||
ssh.Stop()
|
||||
}
|
||||
if sshRun != nil {
|
||||
@@ -78,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
|
||||
// updates the passed-in SSHServer. On success, it returns a function
|
||||
// that callers may invoke to run the configured ssh server. On
|
||||
// failure, it returns nil, error.
|
||||
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
|
||||
func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
|
||||
listen := c.GetString("sshd.listen", "")
|
||||
if listen == "" {
|
||||
return nil, fmt.Errorf("sshd.listen must be provided")
|
||||
@@ -120,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
for _, caAuthorizedKey := range rawCAs {
|
||||
err := ssh.AddTrustedCA(caAuthorizedKey)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
|
||||
l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -131,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
for _, rk := range keys {
|
||||
kDef, ok := rk.(map[string]any)
|
||||
if !ok {
|
||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
|
||||
l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk)
|
||||
continue
|
||||
}
|
||||
|
||||
user, ok := kDef["user"].(string)
|
||||
if !ok {
|
||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
|
||||
l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -146,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
case string:
|
||||
err := ssh.AddAuthorizedKey(user, v)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
|
||||
l.Warn("Failed to authorize key",
|
||||
"error", err,
|
||||
"sshKeyConfig", rk,
|
||||
"sshKey", v,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -154,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
for _, subK := range v {
|
||||
sk, ok := subK.(string)
|
||||
if !ok {
|
||||
l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
|
||||
l.Warn("Did not understand ssh key",
|
||||
"sshKeyConfig", rk,
|
||||
"sshKey", subK,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
err := ssh.AddAuthorizedKey(user, sk)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key")
|
||||
l.Warn("Failed to authorize key",
|
||||
"error", err,
|
||||
"sshKeyConfig", sk,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
|
||||
l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -178,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
ssh.Stop()
|
||||
runner = func() {
|
||||
if err := ssh.Run(listen); err != nil {
|
||||
l.WithField("err", err).Warn("Failed to run the SSH server")
|
||||
l.Warn("Failed to run the SSH server", "error", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -188,7 +198,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
|
||||
func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
|
||||
// sandboxDir defaults to a dir in temp. The intention is that end user will
|
||||
// create this dir as needed. Overriding this config value to "" allows
|
||||
// writing to anywhere in the system.
|
||||
@@ -789,36 +799,45 @@ func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWrit
|
||||
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
||||
}
|
||||
|
||||
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||
func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||
ctrl, ok := l.Handler().(interface {
|
||||
GetLevel() slog.Level
|
||||
SetLevel(slog.Level)
|
||||
})
|
||||
if !ok {
|
||||
return w.WriteLine("Log level is not reconfigurable on this logger")
|
||||
}
|
||||
|
||||
level, err := logrus.ParseLevel(a[0])
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
|
||||
}
|
||||
|
||||
level, err := logging.ParseLevel(strings.ToLower(a[0]))
|
||||
if err != nil {
|
||||
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels))
|
||||
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a))
|
||||
}
|
||||
|
||||
l.SetLevel(level)
|
||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||
ctrl.SetLevel(level)
|
||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
|
||||
}
|
||||
|
||||
func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||
ctrl, ok := l.Handler().(interface {
|
||||
GetFormat() string
|
||||
SetFormat(string) error
|
||||
})
|
||||
if !ok {
|
||||
return w.WriteLine("Log format is not reconfigurable on this logger")
|
||||
}
|
||||
|
||||
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
|
||||
}
|
||||
|
||||
logFormat := strings.ToLower(a[0])
|
||||
switch logFormat {
|
||||
case "text":
|
||||
l.Formatter = &logrus.TextFormatter{}
|
||||
case "json":
|
||||
l.Formatter = &logrus.JSONFormatter{}
|
||||
default:
|
||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
||||
if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
|
||||
}
|
||||
|
||||
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
|
||||
@@ -5,16 +5,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type SSHServer struct {
|
||||
config *ssh.ServerConfig
|
||||
l *logrus.Entry
|
||||
l *slog.Logger
|
||||
|
||||
certChecker *ssh.CertChecker
|
||||
|
||||
@@ -33,7 +33,7 @@ type SSHServer struct {
|
||||
}
|
||||
|
||||
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
||||
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
|
||||
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &SSHServer{
|
||||
@@ -121,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error {
|
||||
}
|
||||
|
||||
s.trustedCAs = append(s.trustedCAs, pk)
|
||||
s.l.WithField("sshKey", pubKey).Info("Trusted CA key")
|
||||
s.l.Info("Trusted CA key", "sshKey", pubKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -139,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
|
||||
}
|
||||
|
||||
tk[string(pk.Marshal())] = true
|
||||
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
|
||||
s.l.Info("Authorized ssh key",
|
||||
"sshKey", pubKey,
|
||||
"sshUser", user,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -156,7 +159,7 @@ func (s *SSHServer) Run(addr string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
s.l.WithField("sshListener", addr).Info("SSH server is listening")
|
||||
s.l.Info("SSH server is listening", "sshListener", addr)
|
||||
|
||||
// Run loops until there is an error
|
||||
s.run()
|
||||
@@ -172,7 +175,7 @@ func (s *SSHServer) run() {
|
||||
c, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
s.l.WithError(err).Warn("Error in listener, shutting down")
|
||||
s.l.Warn("Error in listener, shutting down", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -193,23 +196,29 @@ func (s *SSHServer) run() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
|
||||
l := s.l.With(
|
||||
"error", err,
|
||||
"remoteAddress", c.RemoteAddr(),
|
||||
)
|
||||
if conn != nil {
|
||||
l = l.WithField("sshUser", conn.User())
|
||||
l = l.With("sshUser", conn.User())
|
||||
conn.Close()
|
||||
}
|
||||
if fp != "" {
|
||||
l = l.WithField("sshFingerprint", fp)
|
||||
l = l.With("sshFingerprint", fp)
|
||||
}
|
||||
l.Warn("failed to handshake")
|
||||
sessionCancel()
|
||||
return
|
||||
}
|
||||
|
||||
l := s.l.WithField("sshUser", conn.User())
|
||||
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
|
||||
l := s.l.With("sshUser", conn.User())
|
||||
l.Info("ssh user logged in",
|
||||
"remoteAddress", c.RemoteAddr(),
|
||||
"sshFingerprint", fp,
|
||||
)
|
||||
|
||||
NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session"))
|
||||
NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session"))
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
@@ -221,7 +230,7 @@ func (s *SSHServer) Stop() {
|
||||
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
|
||||
if s.listener != nil {
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.l.WithError(err).Warn("Failed to close the sshd listener")
|
||||
s.l.Warn("Failed to close the sshd listener", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,25 +2,25 @@ package sshd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/anmitsu/go-shlex"
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
l *logrus.Entry
|
||||
l *slog.Logger
|
||||
c *ssh.ServerConn
|
||||
term *term.Terminal
|
||||
commands *radix.Tree
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session {
|
||||
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session {
|
||||
s := &session{
|
||||
commands: radix.NewFromMap(commands.ToMap()),
|
||||
l: l,
|
||||
@@ -45,14 +45,14 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
|
||||
defer s.Close()
|
||||
for newChannel := range chans {
|
||||
if newChannel.ChannelType() != "session" {
|
||||
s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
|
||||
s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType())
|
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
continue
|
||||
}
|
||||
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
s.l.WithError(err).Warn("could not accept channel")
|
||||
s.l.Warn("could not accept channel", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -95,12 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
|
||||
return
|
||||
|
||||
default:
|
||||
s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
|
||||
s.l.Debug("Rejected unknown request", "sshRequest", req.Type)
|
||||
err = req.Reply(false, nil)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.l.WithError(err).Info("Error handling ssh session requests")
|
||||
s.l.Info("Error handling ssh session requests", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
24
stats.go
24
stats.go
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
@@ -15,14 +16,13 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
// startStats initializes stats from config. On success, if any further work
|
||||
// is needed to serve stats, it returns a func to handle that work. If no
|
||||
// work is needed, it'll return nil. On failure, it returns nil, error.
|
||||
func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||
func startStats(l *slog.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||
mType := c.GetString("stats.type", "")
|
||||
if mType == "" || mType == "none" {
|
||||
return nil, nil
|
||||
@@ -59,7 +59,7 @@ func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest b
|
||||
return startFn, nil
|
||||
}
|
||||
|
||||
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
|
||||
func startGraphiteStats(l *slog.Logger, i time.Duration, c *config.C, configTest bool) error {
|
||||
proto := c.GetString("stats.protocol", "tcp")
|
||||
host := c.GetString("stats.host", "")
|
||||
if host == "" {
|
||||
@@ -73,13 +73,17 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTe
|
||||
}
|
||||
|
||||
if !configTest {
|
||||
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
|
||||
l.Info("Starting graphite",
|
||||
"interval", i,
|
||||
"prefix", prefix,
|
||||
"addr", addr.String(),
|
||||
)
|
||||
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||
func startPrometheusStats(l *slog.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||
namespace := c.GetString("stats.namespace", "")
|
||||
subsystem := c.GetString("stats.subsystem", "")
|
||||
|
||||
@@ -116,9 +120,15 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV
|
||||
|
||||
var startFn func()
|
||||
if !configTest {
|
||||
// promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger,
|
||||
// so bridge our slog.Logger back to a *log.Logger that emits at Error.
|
||||
errLog := slog.NewLogLogger(l.Handler(), slog.LevelError)
|
||||
startFn = func() {
|
||||
l.Infof("Prometheus stats listening on %s at %s", listen, path)
|
||||
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
|
||||
l.Info("Prometheus stats listening",
|
||||
"listen", listen,
|
||||
"path", path,
|
||||
)
|
||||
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog}))
|
||||
log.Fatal(http.ListenAndServe(listen, nil))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,29 +1,73 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/logging"
|
||||
)
|
||||
|
||||
func NewLogger() *logrus.Logger {
|
||||
l := logrus.New()
|
||||
|
||||
// NewLogger returns a *slog.Logger suitable for use in tests. Output goes to
|
||||
// io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to
|
||||
// stream output to stderr for local debugging.
|
||||
func NewLogger() *slog.Logger {
|
||||
v := os.Getenv("TEST_LOGS")
|
||||
if v == "" {
|
||||
l.SetOutput(io.Discard)
|
||||
return l
|
||||
return slog.New(slog.DiscardHandler)
|
||||
}
|
||||
|
||||
level := slog.LevelInfo
|
||||
switch v {
|
||||
case "2":
|
||||
l.SetLevel(logrus.DebugLevel)
|
||||
level = slog.LevelDebug
|
||||
case "3":
|
||||
l.SetLevel(logrus.TraceLevel)
|
||||
default:
|
||||
l.SetLevel(logrus.InfoLevel)
|
||||
level = logging.LevelTrace
|
||||
}
|
||||
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
|
||||
}
|
||||
|
||||
return l
|
||||
// NewLoggerWithOutput returns a *slog.Logger whose text output is captured by
|
||||
// w. Timestamps are suppressed so tests can assert on exact output without
|
||||
// baking the current time into expected strings.
|
||||
func NewLoggerWithOutput(w io.Writer) *slog.Logger {
|
||||
return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, nil)})
|
||||
}
|
||||
|
||||
// NewLoggerWithOutputAndLevel is NewLoggerWithOutput with an explicit level
|
||||
// so tests can exercise Enabled-gated paths.
|
||||
func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger {
|
||||
return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})})
|
||||
}
|
||||
|
||||
// NewJSONLoggerWithOutput returns a *slog.Logger emitting JSON to w with
|
||||
// timestamps suppressed, for tests that pin the JSON shape.
|
||||
func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger {
|
||||
return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, &slog.HandlerOptions{Level: level})})
|
||||
}
|
||||
|
||||
// stripTimeHandler zeros each record's time before delegating so slog's
|
||||
// built-in handlers skip emitting the time attribute. Used to avoid
|
||||
// timestamp-dependent assertions in tests without resorting to ReplaceAttr.
|
||||
type stripTimeHandler struct {
|
||||
inner slog.Handler
|
||||
}
|
||||
|
||||
func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool {
|
||||
return h.inner.Enabled(ctx, l)
|
||||
}
|
||||
|
||||
func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
r.Time = time.Time{}
|
||||
return h.inner.Handle(ctx, r)
|
||||
}
|
||||
|
||||
func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)}
|
||||
}
|
||||
|
||||
func (h *stripTimeHandler) WithGroup(name string) slog.Handler {
|
||||
return &stripTimeHandler{inner: h.inner.WithGroup(name)}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,12 @@ import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@@ -22,12 +22,12 @@ type StdConn struct {
|
||||
*net.UDPConn
|
||||
isV4 bool
|
||||
sysFd uintptr
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
var _ Conn = &StdConn{}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
lc := NewListenConfig(multi)
|
||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
||||
if err != nil {
|
||||
@@ -176,7 +176,7 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
||||
return err
|
||||
}
|
||||
|
||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||
u.l.Error("unexpected udp socket receive error", "error", err)
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
@@ -196,7 +196,7 @@ func (u *StdConn) Rebind() error {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
u.l.WithError(err).Error("Failed to rebind udp socket")
|
||||
u.l.Error("Failed to rebind udp socket", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -12,22 +12,22 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
type GenericConn struct {
|
||||
*net.UDPConn
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
var _ Conn = &GenericConn{}
|
||||
|
||||
func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
lc := NewListenConfig(multi)
|
||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
||||
if err != nil {
|
||||
@@ -88,7 +88,7 @@ func (u *GenericConn) ListenOut(r EncReader) error {
|
||||
// Dampen unexpected message warns to once per minute
|
||||
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
||||
lastRecvErr = time.Now()
|
||||
u.l.WithError(err).Warn("unexpected udp socket receive error")
|
||||
u.l.Warn("unexpected udp socket receive error", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@@ -22,7 +22,7 @@ type StdConn struct {
|
||||
udpConn *net.UDPConn
|
||||
rawConn syscall.RawConn
|
||||
isV4 bool
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
batch int
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func setReusePort(network, address string, c syscall.RawConn) error {
|
||||
return opErr
|
||||
}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
listen := netip.AddrPortFrom(ip, uint16(port))
|
||||
lc := net.ListenConfig{}
|
||||
if multi {
|
||||
@@ -242,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
if err == nil {
|
||||
s, err := u.GetRecvBuffer()
|
||||
if err == nil {
|
||||
u.l.WithField("size", s).Info("listen.read_buffer was set")
|
||||
u.l.Info("listen.read_buffer was set", "size", s)
|
||||
} else {
|
||||
u.l.WithError(err).Warn("Failed to get listen.read_buffer")
|
||||
u.l.Warn("Failed to get listen.read_buffer", "error", err)
|
||||
}
|
||||
} else {
|
||||
u.l.WithError(err).Error("Failed to set listen.read_buffer")
|
||||
u.l.Error("Failed to set listen.read_buffer", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
if err == nil {
|
||||
s, err := u.GetSendBuffer()
|
||||
if err == nil {
|
||||
u.l.WithField("size", s).Info("listen.write_buffer was set")
|
||||
u.l.Info("listen.write_buffer was set", "size", s)
|
||||
} else {
|
||||
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
|
||||
u.l.Warn("Failed to get listen.write_buffer", "error", err)
|
||||
}
|
||||
} else {
|
||||
u.l.WithError(err).Error("Failed to set listen.write_buffer")
|
||||
u.l.Error("Failed to set listen.write_buffer", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,12 +273,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
if err == nil {
|
||||
s, err := u.GetSoMark()
|
||||
if err == nil {
|
||||
u.l.WithField("mark", s).Info("listen.so_mark was set")
|
||||
u.l.Info("listen.so_mark was set", "mark", s)
|
||||
} else {
|
||||
u.l.WithError(err).Warn("Failed to get listen.so_mark")
|
||||
u.l.Warn("Failed to get listen.so_mark", "error", err)
|
||||
}
|
||||
} else {
|
||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||
u.l.Error("Failed to set listen.so_mark", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -17,7 +18,6 @@ import (
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/conn/winrio"
|
||||
@@ -53,14 +53,14 @@ type ringBuffer struct {
|
||||
|
||||
type RIOConn struct {
|
||||
isOpen atomic.Bool
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
sock windows.Handle
|
||||
rx, tx ringBuffer
|
||||
rq winrio.Rq
|
||||
results [packetsPerRing]winrio.Result
|
||||
}
|
||||
|
||||
func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
|
||||
func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) {
|
||||
if !winrio.Initialize() {
|
||||
return nil, errors.New("could not initialize winrio")
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
||||
func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
|
||||
var err error
|
||||
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
||||
if err != nil {
|
||||
@@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
||||
if err != nil {
|
||||
// This is a best-effort to prevent errors from being returned by the udp recv operation.
|
||||
// Quietly log a failure and continue.
|
||||
l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl")
|
||||
l.Debug("failed to set UDP_CONNRESET ioctl", "error", err)
|
||||
}
|
||||
|
||||
ret = 0
|
||||
@@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
||||
if err != nil {
|
||||
// This is a best-effort to prevent errors from being returned by the udp recv operation.
|
||||
// Quietly log a failure and continue.
|
||||
l.WithError(err).Debug("failed to set UDP_NETRESET ioctl")
|
||||
l.Debug("failed to set UDP_NETRESET ioctl", "error", err)
|
||||
}
|
||||
|
||||
err = u.rx.Open()
|
||||
@@ -156,7 +156,7 @@ func (u *RIOConn) ListenOut(r EncReader) error {
|
||||
// Dampen unexpected message warns to once per minute
|
||||
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
||||
lastRecvErr = time.Now()
|
||||
u.l.WithError(err).Warn("unexpected udp socket receive error")
|
||||
u.l.Warn("unexpected udp socket receive error", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
@@ -46,10 +47,10 @@ type TesterConn struct {
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
|
||||
l *logrus.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
|
||||
func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
|
||||
return &TesterConn{
|
||||
Addr: netip.AddrPortFrom(ip, uint16(port)),
|
||||
RxPackets: make(chan *Packet, 10),
|
||||
@@ -67,11 +68,12 @@ func (u *TesterConn) Send(packet *Packet) {
|
||||
if err := h.Parse(packet.Data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if u.l.Level >= logrus.DebugLevel {
|
||||
u.l.WithField("header", h).
|
||||
WithField("udpAddr", packet.From).
|
||||
WithField("dataLen", len(packet.Data)).
|
||||
Debug("UDP receiving injected packet")
|
||||
if u.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
u.l.Debug("UDP receiving injected packet",
|
||||
"header", h,
|
||||
"udpAddr", packet.From,
|
||||
"dataLen", len(packet.Data),
|
||||
)
|
||||
}
|
||||
select {
|
||||
case <-u.done:
|
||||
|
||||
@@ -5,14 +5,13 @@ package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
if multi {
|
||||
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
|
||||
// The udp stack would need to be reworked to hide away the implementation differences between
|
||||
@@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
l.WithError(err).Error("Falling back to standard udp sockets")
|
||||
l.Error("Falling back to standard udp sockets", "error", err)
|
||||
return NewGenericListener(l, ip, port, multi, batch)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type ContextualError struct {
|
||||
@@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error {
|
||||
}
|
||||
|
||||
// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
|
||||
func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
|
||||
func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) {
|
||||
switch v := err.(type) {
|
||||
case *ContextualError:
|
||||
v.Log(l)
|
||||
default:
|
||||
l.WithError(err).Error(msg)
|
||||
l.Error(msg, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error {
|
||||
return ce.RealError
|
||||
}
|
||||
|
||||
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
||||
// Log emits ce as a single error-level log line with Fields and RealError
|
||||
// promoted to top-level attributes, producing a flat shape callers can grep
|
||||
// or parse without walking into a nested object.
|
||||
func (ce *ContextualError) Log(l *slog.Logger) {
|
||||
attrs := make([]slog.Attr, 0, len(ce.Fields)+1)
|
||||
for k, v := range ce.Fields {
|
||||
attrs = append(attrs, slog.Any(k, v))
|
||||
}
|
||||
if ce.RealError != nil {
|
||||
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
||||
} else {
|
||||
lr.WithFields(ce.Fields).Error(ce.Context)
|
||||
attrs = append(attrs, slog.Any("error", ce.RealError))
|
||||
}
|
||||
// LogAttrs is intentional: attrs is built from a map[string]any so it has
|
||||
// no pair-form equivalent.
|
||||
//nolint:sloglint
|
||||
l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...)
|
||||
}
|
||||
|
||||
@@ -1,95 +1,67 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type m = map[string]any
|
||||
|
||||
type TestLogWriter struct {
|
||||
Logs []string
|
||||
}
|
||||
|
||||
func NewTestLogWriter() *TestLogWriter {
|
||||
return &TestLogWriter{Logs: make([]string, 0)}
|
||||
}
|
||||
|
||||
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
|
||||
tl.Logs = append(tl.Logs, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (tl *TestLogWriter) Reset() {
|
||||
tl.Logs = tl.Logs[:0]
|
||||
}
|
||||
|
||||
func TestContextualError_Log(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
DisableTimestamp: true,
|
||||
DisableColors: true,
|
||||
}
|
||||
|
||||
tl := NewTestLogWriter()
|
||||
l.Out = tl
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutput(buf)
|
||||
|
||||
// Test a full context line
|
||||
tl.Reset()
|
||||
buf.Reset()
|
||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
||||
assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())
|
||||
|
||||
// Test a line with an error and msg but no fields
|
||||
tl.Reset()
|
||||
buf.Reset()
|
||||
e = NewContextualError("test message", nil, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
|
||||
assert.Equal(t, "level=ERROR msg=\"test message\" error=error\n", buf.String())
|
||||
|
||||
// Test just a context and fields
|
||||
tl.Reset()
|
||||
buf.Reset()
|
||||
e = NewContextualError("test message", m{"field": "1"}, nil)
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
|
||||
assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String())
|
||||
|
||||
// Test just a context
|
||||
tl.Reset()
|
||||
buf.Reset()
|
||||
e = NewContextualError("test message", nil, nil)
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
|
||||
assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String())
|
||||
|
||||
// Test just an error
|
||||
tl.Reset()
|
||||
buf.Reset()
|
||||
e = NewContextualError("", nil, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
|
||||
assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String())
|
||||
}
|
||||
|
||||
func TestLogWithContextIfNeeded(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
DisableTimestamp: true,
|
||||
DisableColors: true,
|
||||
}
|
||||
|
||||
tl := NewTestLogWriter()
|
||||
l.Out = tl
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutput(buf)
|
||||
|
||||
// Test ignoring fallback context
|
||||
tl.Reset()
|
||||
buf.Reset()
|
||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||
LogWithContextIfNeeded("This should get thrown away", e, l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
||||
assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())
|
||||
|
||||
// Test using fallback context
|
||||
tl.Reset()
|
||||
buf.Reset()
|
||||
err := fmt.Errorf("this is a normal error")
|
||||
LogWithContextIfNeeded("Fallback context woo", err, l)
|
||||
assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs)
|
||||
assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String())
|
||||
}
|
||||
|
||||
func TestContextualizeIfNeeded(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user