mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Switch to slog, remove logrus (#1672)
This commit is contained in:
@@ -2,7 +2,21 @@ version: "2"
|
|||||||
linters:
|
linters:
|
||||||
default: none
|
default: none
|
||||||
enable:
|
enable:
|
||||||
|
- sloglint
|
||||||
- testifylint
|
- testifylint
|
||||||
|
settings:
|
||||||
|
sloglint:
|
||||||
|
# Enforce key-value pair form for Info/Debug/Warn/Error/Log/With and
|
||||||
|
# the package-level slog equivalents. Use l.Log(ctx, level, ...) for
|
||||||
|
# custom levels instead of LogAttrs when you can.
|
||||||
|
#
|
||||||
|
# LogAttrs is also flagged by this rule because it takes ...slog.Attr;
|
||||||
|
# the few legitimate sites (where attrs is built up as a []slog.Attr)
|
||||||
|
# carry a //nolint:sloglint with rationale.
|
||||||
|
kv-only: true
|
||||||
|
# no-mixed-args is on by default: forbids mixing kv and attrs in one call.
|
||||||
|
# discard-handler is on by default (since Go 1.24): suggests
|
||||||
|
# slog.DiscardHandler over slog.NewTextHandler(io.Discard, nil).
|
||||||
exclusions:
|
exclusions:
|
||||||
generated: lax
|
generated: lax
|
||||||
presets:
|
presets:
|
||||||
|
|||||||
38
bits.go
38
bits.go
@@ -1,8 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bits struct {
|
type Bits struct {
|
||||||
@@ -30,7 +32,7 @@ func NewBits(bits uint64) *Bits {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true.
|
// If i is the next number, return true.
|
||||||
if i > b.current {
|
if i > b.current {
|
||||||
return true
|
return true
|
||||||
@@ -42,13 +44,16 @@ func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
l.Debug("rejected a packet (top)",
|
||||||
|
"current", b.current,
|
||||||
|
"incoming", i,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true and update current.
|
// If i is the next number, return true and update current.
|
||||||
if i == b.current+1 {
|
if i == b.current+1 {
|
||||||
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
||||||
@@ -87,9 +92,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
// Check to see if it's a duplicate
|
// Check to see if it's a duplicate
|
||||||
if i > b.current-b.length || i < b.length && b.current < b.length {
|
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||||
if b.current == i || b.bits[i%b.length] == true {
|
if b.current == i || b.bits[i%b.length] == true {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
l.Debug("Receive window",
|
||||||
Debug("Receive window")
|
"accepted", false,
|
||||||
|
"currentCounter", b.current,
|
||||||
|
"incomingCounter", i,
|
||||||
|
"reason", "duplicate",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
b.dupeCounter.Inc(1)
|
b.dupeCounter.Inc(1)
|
||||||
return false
|
return false
|
||||||
@@ -101,12 +110,13 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
|
|
||||||
// In all other cases, fail and don't change current.
|
// In all other cases, fail and don't change current.
|
||||||
b.outOfWindowCounter.Inc(1)
|
b.outOfWindowCounter.Inc(1)
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("accepted", false).
|
l.Debug("Receive window",
|
||||||
WithField("currentCounter", b.current).
|
"accepted", false,
|
||||||
WithField("incomingCounter", i).
|
"currentCounter", b.current,
|
||||||
WithField("reason", "nonsense").
|
"incomingCounter", i,
|
||||||
Debug("Receive window")
|
"reason", "nonsense",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,15 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
func HookLogger(l *logrus.Logger) {
|
"github.com/slackhq/nebula/logging"
|
||||||
// Do nothing, let the logs flow to stdout/stderr
|
)
|
||||||
|
|
||||||
|
// newPlatformLogger returns a *slog.Logger that writes to stdout. Non-Windows
|
||||||
|
// platforms have no special sink to integrate with.
|
||||||
|
func newPlatformLogger() *slog.Logger {
|
||||||
|
return logging.NewLogger(os.Stdout)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,54 +1,86 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"context"
|
||||||
"io/ioutil"
|
"log/slog"
|
||||||
"os"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
|
// newPlatformLogger returns a *slog.Logger that routes every log record
|
||||||
// logrus output will be discarded
|
// through the Windows service logger so records end up in the Windows
|
||||||
func HookLogger(l *logrus.Logger) {
|
// Event Log. All the heavy lifting (level management, format swap,
|
||||||
l.AddHook(newLogHook(logger))
|
// timestamp toggle, WithAttrs/WithGroup) comes from logging.NewHandler;
|
||||||
l.SetOutput(ioutil.Discard)
|
// this file only contributes:
|
||||||
|
//
|
||||||
|
// - an io.Writer that forwards each formatted line to the service
|
||||||
|
// logger at the current record's Event Log severity, and
|
||||||
|
// - a thin severityTag that embeds *logging.Handler and overrides
|
||||||
|
// only Handle / WithAttrs / WithGroup, so Event Viewer's severity
|
||||||
|
// column and severity-based filters keep working the way they did
|
||||||
|
// before the slog migration.
|
||||||
|
//
|
||||||
|
// Format (text vs json) is carried by the embedded *logging.Handler, so
|
||||||
|
// logging.format: json in config still produces JSON lines in Event
|
||||||
|
// Viewer, same as the pre-slog logrus setup.
|
||||||
|
func newPlatformLogger() *slog.Logger {
|
||||||
|
w := &eventLogWriter{}
|
||||||
|
return slog.New(&severityTag{Handler: logging.NewHandler(w), w: w})
|
||||||
}
|
}
|
||||||
|
|
||||||
type logHook struct {
|
// eventLogWriter forwards slog-formatted lines to the Windows service
|
||||||
sl service.Logger
|
// logger at the severity most recently stashed by severityTag.Handle.
|
||||||
|
// The mutex serializes the stash + inner.Handle + Write cycle per record
|
||||||
|
// across all concurrent goroutines; slog's builtin text/json handlers
|
||||||
|
// each hold their own mutex around Write, but that only protects the
|
||||||
|
// Write call itself, not our stash-then-handle sequence.
|
||||||
|
type eventLogWriter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
level slog.Level
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLogHook(sl service.Logger) *logHook {
|
func (w *eventLogWriter) Write(p []byte) (int, error) {
|
||||||
return &logHook{sl: sl}
|
line := strings.TrimRight(string(p), "\n")
|
||||||
}
|
switch {
|
||||||
|
case w.level >= slog.LevelError:
|
||||||
func (h *logHook) Fire(entry *logrus.Entry) error {
|
return len(p), logger.Error(line)
|
||||||
line, err := entry.String()
|
case w.level >= slog.LevelWarn:
|
||||||
if err != nil {
|
return len(p), logger.Warning(line)
|
||||||
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch entry.Level {
|
|
||||||
case logrus.PanicLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.FatalLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.ErrorLevel:
|
|
||||||
return h.sl.Error(line)
|
|
||||||
case logrus.WarnLevel:
|
|
||||||
return h.sl.Warning(line)
|
|
||||||
case logrus.InfoLevel:
|
|
||||||
return h.sl.Info(line)
|
|
||||||
case logrus.DebugLevel:
|
|
||||||
return h.sl.Info(line)
|
|
||||||
default:
|
default:
|
||||||
return nil
|
return len(p), logger.Info(line)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *logHook) Levels() []logrus.Level {
|
// severityTag embeds *logging.Handler to pick up everything it does for
|
||||||
return logrus.AllLevels
|
// free (Enabled, SetLevel, GetLevel, SetFormat, GetFormat,
|
||||||
|
// SetDisableTimestamp) and overrides only Handle / WithAttrs / WithGroup
|
||||||
|
// so each record's slog.Level is stashed on the writer before formatting
|
||||||
|
// and so derived handlers stay wrapped as severityTag rather than
|
||||||
|
// downgrading to bare *logging.Handler.
|
||||||
|
type severityTag struct {
|
||||||
|
*logging.Handler
|
||||||
|
w *eventLogWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) Handle(ctx context.Context, r slog.Record) error {
|
||||||
|
s.w.mu.Lock()
|
||||||
|
defer s.w.mu.Unlock()
|
||||||
|
s.w.level = r.Level
|
||||||
|
return s.Handler.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return &severityTag{Handler: s.Handler.WithAttrs(attrs).(*logging.Handler), w: s.w}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *severityTag) WithGroup(name string) slog.Handler {
|
||||||
|
if name == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return &severityTag{Handler: s.Handler.WithGroup(name).(*logging.Handler), w: s.w}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,12 +50,11 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
l := logging.NewLogger(os.Stdout)
|
||||||
l.Out = os.Stdout
|
|
||||||
|
|
||||||
if *serviceFlag != "" {
|
if *serviceFlag != "" {
|
||||||
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
|
if err := doService(configPath, configTest, Build, serviceFlag); err != nil {
|
||||||
l.WithError(err).Error("Service command failed")
|
l.Error("Service command failed", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -74,6 +73,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
fmt.Printf("failed to apply logging config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||||
@@ -90,7 +99,7 @@ func main() {
|
|||||||
go ctrl.ShutdownBlock()
|
go ctrl.ShutdownBlock()
|
||||||
|
|
||||||
if err := wait(); err != nil {
|
if err := wait(); err != nil {
|
||||||
l.WithError(err).Error("Nebula stopped due to fatal error")
|
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger service.Logger
|
var logger service.Logger
|
||||||
@@ -25,8 +25,7 @@ func (p *program) Start(s service.Service) error {
|
|||||||
// Start should not block.
|
// Start should not block.
|
||||||
logger.Info("Nebula service starting.")
|
logger.Info("Nebula service starting.")
|
||||||
|
|
||||||
l := logrus.New()
|
l := newPlatformLogger()
|
||||||
HookLogger(l)
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*p.configPath)
|
err := c.Load(*p.configPath)
|
||||||
@@ -34,6 +33,15 @@ func (p *program) Start(s service.Service) error {
|
|||||||
return fmt.Errorf("failed to load config: %s", err)
|
return fmt.Errorf("failed to load config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
return fmt.Errorf("failed to apply logging config: %s", err)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -85,7 +93,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
|
|||||||
// Here are what the different loggers are doing:
|
// Here are what the different loggers are doing:
|
||||||
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
|
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
|
||||||
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
|
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
|
||||||
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
|
// - in program.Start we build a *slog.Logger via newPlatformLogger; on non-Windows that is a stdout-backed slog logger, on Windows it routes records through the service logger
|
||||||
s, err := service.New(prg, svcConfig)
|
s, err := service.New(prg, svcConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,8 +55,7 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logrus.New()
|
l := logging.NewLogger(os.Stdout)
|
||||||
l.Out = os.Stdout
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
err := c.Load(*configPath)
|
err := c.Load(*configPath)
|
||||||
@@ -65,6 +64,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
fmt.Printf("failed to apply logging config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := logging.ApplyConfig(l, c); err != nil {
|
||||||
|
l.Error("Failed to reconfigure logger on reload", "error", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||||
@@ -82,7 +91,7 @@ func main() {
|
|||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
|
|
||||||
if err := wait(); err != nil {
|
if err := wait(); err != nil {
|
||||||
l.WithError(err).Error("Nebula stopped due to fatal error")
|
l.Error("Nebula stopped due to fatal error", "error", err)
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
|
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
|
||||||
@@ -13,30 +12,30 @@ import (
|
|||||||
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
|
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
|
||||||
const SdNotifyReady = "READY=1"
|
const SdNotifyReady = "READY=1"
|
||||||
|
|
||||||
func notifyReady(l *logrus.Logger) {
|
func notifyReady(l *slog.Logger) {
|
||||||
sockName := os.Getenv("NOTIFY_SOCKET")
|
sockName := os.Getenv("NOTIFY_SOCKET")
|
||||||
if sockName == "" {
|
if sockName == "" {
|
||||||
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
l.Debug("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
|
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("failed to connect to systemd notification socket")
|
l.Error("failed to connect to systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
|
l.Error("failed to set the write deadline for the systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
|
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
|
||||||
l.WithError(err).Error("failed to signal the systemd notification socket")
|
l.Error("failed to signal the systemd notification socket", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Debugln("notified systemd the service is ready")
|
l.Debug("notified systemd the service is ready")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import "log/slog"
|
||||||
|
|
||||||
func notifyReady(_ *logrus.Logger) {
|
func notifyReady(_ *slog.Logger) {
|
||||||
// No init service to notify
|
// No init service to notify
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -16,7 +17,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,11 +26,11 @@ type C struct {
|
|||||||
Settings map[string]any
|
Settings map[string]any
|
||||||
oldSettings map[string]any
|
oldSettings map[string]any
|
||||||
callbacks []func(*C)
|
callbacks []func(*C)
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
reloadLock sync.Mutex
|
reloadLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewC(l *logrus.Logger) *C {
|
func NewC(l *slog.Logger) *C {
|
||||||
return &C{
|
return &C{
|
||||||
Settings: make(map[string]any),
|
Settings: make(map[string]any),
|
||||||
l: l,
|
l: l,
|
||||||
@@ -107,12 +107,18 @@ func (c *C) HasChanged(k string) bool {
|
|||||||
|
|
||||||
newVals, err := yaml.Marshal(nv)
|
newVals, err := yaml.Marshal(nv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
c.l.Error("Error while marshaling new config",
|
||||||
|
"config_path", k,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldVals, err := yaml.Marshal(ov)
|
oldVals, err := yaml.Marshal(ov)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
c.l.Error("Error while marshaling old config",
|
||||||
|
"config_path", k,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(newVals) != string(oldVals)
|
return string(newVals) != string(oldVals)
|
||||||
@@ -154,7 +160,10 @@ func (c *C) ReloadConfig() {
|
|||||||
|
|
||||||
err := c.Load(c.path)
|
err := c.Load(c.path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
c.l.Error("Error occurred while reloading config",
|
||||||
|
"config_path", c.path,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -47,10 +47,10 @@ type connectionManager struct {
|
|||||||
|
|
||||||
metricsTxPunchy metrics.Counter
|
metricsTxPunchy metrics.Counter
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||||
cm := &connectionManager{
|
cm := &connectionManager{
|
||||||
hostMap: hm,
|
hostMap: hm,
|
||||||
l: l,
|
l: l,
|
||||||
@@ -85,9 +85,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
|||||||
old := cm.getInactivityTimeout()
|
old := cm.getInactivityTimeout()
|
||||||
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
||||||
if !initial {
|
if !initial {
|
||||||
cm.l.WithField("oldDuration", old).
|
cm.l.Info("Inactivity timeout has changed",
|
||||||
WithField("newDuration", cm.getInactivityTimeout()).
|
"oldDuration", old,
|
||||||
Info("Inactivity timeout has changed")
|
"newDuration", cm.getInactivityTimeout(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,9 +96,10 @@ func (cm *connectionManager) reload(c *config.C, initial bool) {
|
|||||||
old := cm.dropInactive.Load()
|
old := cm.dropInactive.Load()
|
||||||
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
||||||
if !initial {
|
if !initial {
|
||||||
cm.l.WithField("oldBool", old).
|
cm.l.Info("Drop inactive setting has changed",
|
||||||
WithField("newBool", cm.dropInactive.Load()).
|
"oldBool", old,
|
||||||
Info("Drop inactive setting has changed")
|
"newBool", cm.dropInactive.Load(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -256,7 +258,7 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
var err error
|
var err error
|
||||||
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
cm.l.Error("failed to migrate relay to new hostinfo", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
@@ -304,16 +306,16 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
|
|||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
|
||||||
} else {
|
} else {
|
||||||
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
cm.l.WithFields(logrus.Fields{
|
cm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": req.RelayFromAddr,
|
"relayFrom", req.RelayFromAddr,
|
||||||
"relayTo": req.RelayToAddr,
|
"relayTo", req.RelayToAddr,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex", req.ResponderRelayIndex,
|
||||||
"vpnAddrs": newhostinfo.vpnAddrs}).
|
"vpnAddrs", newhostinfo.vpnAddrs,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -325,7 +327,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
hostinfo := cm.hostMap.Indexes[localIndex]
|
hostinfo := cm.hostMap.Indexes[localIndex]
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
|
cm.l.Debug("Not found in hostmap", "localIndex", localIndex)
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,10 +347,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
// A hostinfo is determined alive if there is incoming traffic
|
// A hostinfo is determined alive if there is incoming traffic
|
||||||
if inTraffic {
|
if inTraffic {
|
||||||
decision := doNothing
|
decision := doNothing
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
"tunnelCheck", m{"state": "alive", "method": "passive"},
|
||||||
Debug("Tunnel status")
|
)
|
||||||
}
|
}
|
||||||
hostinfo.pendingDeletion.Store(false)
|
hostinfo.pendingDeletion.Store(false)
|
||||||
|
|
||||||
@@ -375,9 +377,9 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if hostinfo.pendingDeletion.Load() {
|
if hostinfo.pendingDeletion.Load() {
|
||||||
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Info("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
"tunnelCheck", m{"state": "dead", "method": "active"},
|
||||||
Info("Tunnel status")
|
)
|
||||||
|
|
||||||
return deleteTunnel, hostinfo, nil
|
return deleteTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
@@ -388,10 +390,10 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
||||||
if isInactive {
|
if isInactive {
|
||||||
// Tunnel is inactive, tear it down
|
// Tunnel is inactive, tear it down
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Info("Dropping tunnel due to inactivity",
|
||||||
WithField("inactiveDuration", inactiveFor).
|
"inactiveDuration", inactiveFor,
|
||||||
WithField("primary", mainHostInfo).
|
"primary", mainHostInfo,
|
||||||
Info("Dropping tunnel due to inactivity")
|
)
|
||||||
|
|
||||||
return closeTunnel, hostinfo, primary
|
return closeTunnel, hostinfo, primary
|
||||||
}
|
}
|
||||||
@@ -410,18 +412,18 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
cm.sendPunch(hostinfo)
|
cm.sendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).
|
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
"tunnelCheck", m{"state": "testing", "method": "active"},
|
||||||
Debug("Tunnel status")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
decision = sendTestPacket
|
decision = sendTestPacket
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.l.Level >= logrus.DebugLevel {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
|
hostinfo.logger(cm.l).Debug("Hostinfo sadness")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -493,14 +495,16 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
|
|||||||
return false //cert is still valid! yay!
|
return false //cert is still valid! yay!
|
||||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).Info("Remote certificate is blocked, tearing down the tunnel",
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
"error", err,
|
||||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
"fingerprint", remoteCert.Fingerprint,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
} else if cm.intf.disconnectInvalid.Load() {
|
} else if cm.intf.disconnectInvalid.Load() {
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).Info("Remote certificate is no longer valid, tearing down the tunnel",
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
"error", err,
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
"fingerprint", remoteCert.Fingerprint,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
} else {
|
} else {
|
||||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||||
@@ -539,10 +543,11 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
curCrtVersion := curCrt.Version()
|
curCrtVersion := curCrt.Version()
|
||||||
myCrt := cs.getCertificate(curCrtVersion)
|
myCrt := cs.getCertificate(curCrtVersion)
|
||||||
if myCrt == nil {
|
if myCrt == nil {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("version", curCrtVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("reason", "local certificate removed").
|
"version", curCrtVersion,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate removed",
|
||||||
|
)
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -550,11 +555,12 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("version", curCrtVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
"version", curCrtVersion,
|
||||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
"peerVersion", peerCrt.Certificate.Version(),
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate version lower than peer, attempting to correct",
|
||||||
|
)
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||||
})
|
})
|
||||||
@@ -562,17 +568,19 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("reason", "local certificate is not current").
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "local certificate is not current",
|
||||||
|
)
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if curCrtVersion < cs.initiatingVersion {
|
if curCrtVersion < cs.initiatingVersion {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.Info("Re-handshaking with remote",
|
||||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("Re-handshaking with remote")
|
"reason", "current cert version < pki.initiatingVersion",
|
||||||
|
)
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/overlaytest"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -52,7 +53,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -63,9 +64,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
@@ -135,7 +136,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -146,9 +147,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
@@ -220,7 +221,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -231,12 +232,12 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
conf.Settings["tunnels"] = map[string]any{
|
conf.Settings["tunnels"] = map[string]any{
|
||||||
"drop_inactive": true,
|
"drop_inactive": true,
|
||||||
}
|
}
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
assert.True(t, nc.dropInactive.Load())
|
assert.True(t, nc.dropInactive.Load())
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
|
|
||||||
@@ -347,7 +348,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &overlaytest.NoopTun{},
|
||||||
outside: &udp.NoopConn{},
|
outside: &udp.NoopConn{},
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
@@ -360,9 +361,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
ifce.disconnectInvalid.Store(true)
|
ifce.disconnectInvalid.Store(true)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(l, conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
||||||
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
@@ -27,7 +26,7 @@ type ConnectionState struct {
|
|||||||
writeLock sync.Mutex
|
writeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||||
var dhFunc noise.DHFunc
|
var dhFunc noise.DHFunc
|
||||||
switch crt.Curve() {
|
switch crt.Curve() {
|
||||||
case cert.Curve_CURVE25519:
|
case cert.Curve_CURVE25519:
|
||||||
|
|||||||
14
control.go
14
control.go
@@ -3,13 +3,13 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
@@ -46,7 +46,7 @@ type Control struct {
|
|||||||
state RunState
|
state RunState
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
sshStart func()
|
sshStart func()
|
||||||
@@ -151,7 +151,7 @@ func (c *Control) Stop() {
|
|||||||
|
|
||||||
c.CloseAllTunnels(false)
|
c.CloseAllTunnels(false)
|
||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.Error("Close interface failed", "error", err)
|
||||||
}
|
}
|
||||||
c.stateLock.Lock()
|
c.stateLock.Lock()
|
||||||
c.state = StateStopped
|
c.state = StateStopped
|
||||||
@@ -166,7 +166,7 @@ func (c *Control) ShutdownBlock() {
|
|||||||
|
|
||||||
rawSig := <-sigChan
|
rawSig := <-sigChan
|
||||||
sig := rawSig.String()
|
sig := rawSig.String()
|
||||||
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
|
c.l.Info("Caught signal, shutting down", "signal", sig)
|
||||||
c.Stop()
|
c.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,8 +303,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
|||||||
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
c.f.closeTunnel(h)
|
c.f.closeTunnel(h)
|
||||||
|
|
||||||
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
|
c.l.Debug("Sending close tunnel message",
|
||||||
Debug("Sending close tunnel message")
|
"vpnAddrs", h.vpnAddrs,
|
||||||
|
"udpAddr", h.remote,
|
||||||
|
)
|
||||||
closed++
|
closed++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -83,7 +82,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
f: &Interface{
|
f: &Interface{
|
||||||
hostMap: hm,
|
hostMap: hm,
|
||||||
},
|
},
|
||||||
l: logrus.New(),
|
l: test.NewLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -12,13 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dnsServer struct {
|
type dnsServer struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
dnsMap4 map[string]netip.Addr
|
dnsMap4 map[string]netip.Addr
|
||||||
dnsMap6 map[string]netip.Addr
|
dnsMap6 map[string]netip.Addr
|
||||||
@@ -55,7 +55,7 @@ type dnsServer struct {
|
|||||||
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
|
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
|
||||||
// watcher that tears the listener down on nebula shutdown. The returned
|
// watcher that tears the listener down on nebula shutdown. The returned
|
||||||
// pointer is always non-nil, even on error.
|
// pointer is always non-nil, even on error.
|
||||||
func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||||
ds := &dnsServer{
|
ds := &dnsServer{
|
||||||
l: l,
|
l: l,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -69,7 +69,7 @@ func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState
|
|||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
if err := ds.reload(c, false); err != nil {
|
if err := ds.reload(c, false); err != nil {
|
||||||
l.WithError(err).Error("Failed to reload DNS responder from config")
|
ds.l.Error("Failed to reload DNS responder from config", "error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reaso
|
|||||||
<-started
|
<-started
|
||||||
}
|
}
|
||||||
if err := srv.Shutdown(); err != nil {
|
if err := srv.Shutdown(); err != nil {
|
||||||
d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder")
|
d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,7 +188,7 @@ func (d *dnsServer) Start() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
d.l.WithField("dnsListener", addr).Info("Starting DNS responder")
|
d.l.Info("Starting DNS responder", "dnsListener", addr)
|
||||||
err := server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
close(done)
|
close(done)
|
||||||
|
|
||||||
@@ -201,7 +201,7 @@ func (d *dnsServer) Start() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.l.WithError(err).Warn("Failed to run the DNS responder")
|
d.l.Warn("Failed to run the DNS responder", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,6 +314,7 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
|
debugEnabled := d.l.Enabled(context.Background(), slog.LevelDebug)
|
||||||
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
|
// Per RFC 2308 §2.2, a name that exists but has no record of the requested
|
||||||
// type must be answered with NOERROR and an empty answer section (NODATA),
|
// type must be answered with NOERROR and an empty answer section (NODATA),
|
||||||
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
|
// not NXDOMAIN (RFC 2308 §2.1), which is reserved for names that do not
|
||||||
@@ -323,7 +324,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
switch q.Qtype {
|
switch q.Qtype {
|
||||||
case dns.TypeA, dns.TypeAAAA:
|
case dns.TypeA, dns.TypeAAAA:
|
||||||
qType := dns.TypeToString[q.Qtype]
|
qType := dns.TypeToString[q.Qtype]
|
||||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
if debugEnabled {
|
||||||
|
d.l.Debug("DNS query", "type", qType, "name", q.Name)
|
||||||
|
}
|
||||||
ip, nameExists := d.Query(q.Qtype, q.Name)
|
ip, nameExists := d.Query(q.Qtype, q.Name)
|
||||||
if nameExists {
|
if nameExists {
|
||||||
anyNameExists = true
|
anyNameExists = true
|
||||||
@@ -339,7 +342,9 @@ func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
d.l.Debugf("Query for TXT %s", q.Name)
|
if debugEnabled {
|
||||||
|
d.l.Debug("DNS query", "type", "TXT", "name", q.Name)
|
||||||
|
}
|
||||||
ip := d.QueryCert(q.Name)
|
ip := d.QueryCert(q.Name)
|
||||||
if ip != "" {
|
if ip != "" {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -30,7 +29,7 @@ func (stubDNSWriter) TsigTimersOnly(bool) {}
|
|||||||
func (stubDNSWriter) Hijack() {}
|
func (stubDNSWriter) Hijack() {}
|
||||||
|
|
||||||
func TestParsequery(t *testing.T) {
|
func TestParsequery(t *testing.T) {
|
||||||
l := logrus.New()
|
l := slog.New(slog.DiscardHandler)
|
||||||
hostMap := &HostMap{}
|
hostMap := &HostMap{}
|
||||||
ds := &dnsServer{
|
ds := &dnsServer{
|
||||||
l: l,
|
l: l,
|
||||||
@@ -137,10 +136,9 @@ func Test_getDnsServerAddr(t *testing.T) {
|
|||||||
|
|
||||||
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
l := logrus.New()
|
sl := slog.New(slog.DiscardHandler)
|
||||||
l.Out = io.Discard
|
|
||||||
ds := &dnsServer{
|
ds := &dnsServer{
|
||||||
l: l,
|
l: sl,
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
dnsMap4: make(map[string]netip.Addr),
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
dnsMap6: make(map[string]netip.Addr),
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
@@ -148,7 +146,7 @@ func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
|||||||
}
|
}
|
||||||
ds.mux = dns.NewServeMux()
|
ds.mux = dns.NewServeMux()
|
||||||
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||||
return ds, config.NewC(l)
|
return ds, config.NewC(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
|
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
@@ -749,7 +748,6 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
@@ -771,49 +769,41 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Get a tunnel between me and relay")
|
r.Log("Get a tunnel between me and relay")
|
||||||
l.Info("Get a tunnel between me and relay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Get a tunnel between them and relay")
|
r.Log("Get a tunnel between them and relay")
|
||||||
l.Info("Get a tunnel between them and relay")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||||
l.Info("Trigger a handshake from both them and me via relay to them and me")
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
||||||
|
|
||||||
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
||||||
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
||||||
|
|
||||||
r.Log("Wait for a packet from them to me")
|
r.Log("Wait for a packet from them to me; myControl")
|
||||||
l.Info("Wait for a packet from them to me; myControl")
|
|
||||||
r.RouteForAllUntilTxTun(myControl)
|
r.RouteForAllUntilTxTun(myControl)
|
||||||
l.Info("Wait for a packet from them to me; theirControl")
|
r.Log("Wait for a packet from them to me; theirControl")
|
||||||
r.RouteForAllUntilTxTun(theirControl)
|
r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
t.Log("Wait until we remove extra tunnels")
|
t.Log("Wait until we remove extra tunnels")
|
||||||
l.Info("Wait until we remove extra tunnels")
|
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||||
l.WithFields(
|
len(myControl.GetHostmap().Indexes),
|
||||||
logrus.Fields{
|
len(theirControl.GetHostmap().Indexes),
|
||||||
"myControl": len(myControl.GetHostmap().Indexes),
|
len(relayControl.GetHostmap().Indexes),
|
||||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
)
|
||||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
|
||||||
}).Info("Waiting for hostinfos to be removed...")
|
|
||||||
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
hostInfos := len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||||
retries := 60
|
retries := 60
|
||||||
for hostInfos > 6 && retries > 0 {
|
for hostInfos > 6 && retries > 0 {
|
||||||
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
hostInfos = len(myControl.GetHostmap().Indexes) + len(theirControl.GetHostmap().Indexes) + len(relayControl.GetHostmap().Indexes)
|
||||||
l.WithFields(
|
t.Logf("Waiting for hostinfos to be removed... myControl=%d theirControl=%d relayControl=%d",
|
||||||
logrus.Fields{
|
len(myControl.GetHostmap().Indexes),
|
||||||
"myControl": len(myControl.GetHostmap().Indexes),
|
len(theirControl.GetHostmap().Indexes),
|
||||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
len(relayControl.GetHostmap().Indexes),
|
||||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
)
|
||||||
}).Info("Waiting for hostinfos to be removed...")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
t.Log("Connection manager hasn't ticked yet")
|
t.Log("Connection manager hasn't ticked yet")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -821,7 +811,6 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
|
||||||
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,15 +11,18 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
@@ -132,8 +134,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
|||||||
"port": udpAddr.Port(),
|
"port": udpAddr.Port(),
|
||||||
},
|
},
|
||||||
"logging": m{
|
"logging": m{
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
|
"level": testLogLevelName(),
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
},
|
||||||
"timers": m{
|
"timers": m{
|
||||||
"pending_deletion_interval": 2,
|
"pending_deletion_interval": 2,
|
||||||
@@ -234,8 +235,7 @@ func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, o
|
|||||||
"port": udpAddr.Port(),
|
"port": udpAddr.Port(),
|
||||||
},
|
},
|
||||||
"logging": m{
|
"logging": m{
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
"level": testLogLevelName(),
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
},
|
||||||
"timers": m{
|
"timers": m{
|
||||||
"pending_deletion_interval": 2,
|
"pending_deletion_interval": 2,
|
||||||
@@ -379,24 +379,32 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestLogger() *logrus.Logger {
|
func NewTestLogger() *slog.Logger {
|
||||||
l := logrus.New()
|
|
||||||
|
|
||||||
v := os.Getenv("TEST_LOGS")
|
v := os.Getenv("TEST_LOGS")
|
||||||
if v == "" {
|
if v == "" {
|
||||||
l.SetOutput(io.Discard)
|
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
l.SetLevel(logrus.PanicLevel)
|
|
||||||
return l
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
level := slog.LevelInfo
|
||||||
switch v {
|
switch v {
|
||||||
case "2":
|
case "2":
|
||||||
l.SetLevel(logrus.DebugLevel)
|
level = slog.LevelDebug
|
||||||
case "3":
|
case "3":
|
||||||
l.SetLevel(logrus.TraceLevel)
|
level = logging.LevelTrace
|
||||||
default:
|
}
|
||||||
l.SetLevel(logrus.InfoLevel)
|
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
|
||||||
}
|
}
|
||||||
|
|
||||||
return l
|
// testLogLevelName returns the level name string accepted by logging.ApplyConfig
|
||||||
|
// for the current TEST_LOGS setting. Kept in sync with NewTestLogger.
|
||||||
|
func testLogLevelName() string {
|
||||||
|
switch os.Getenv("TEST_LOGS") {
|
||||||
|
case "2":
|
||||||
|
return "debug"
|
||||||
|
case "3":
|
||||||
|
return "trace"
|
||||||
|
case "":
|
||||||
|
return "info"
|
||||||
|
}
|
||||||
|
return "info"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -292,23 +292,17 @@ tun:
|
|||||||
|
|
||||||
# Configure logging level
|
# Configure logging level
|
||||||
logging:
|
logging:
|
||||||
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
|
# trace, debug, info, warn, or error. Default is info and is reloadable.
|
||||||
#NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
|
# fatal and panic are accepted for backwards compatibility and map to error.
|
||||||
# scenarios. Debug logging is also CPU intensive and will decrease performance overall.
|
#NOTE: Debug and trace modes can log remotely controlled/untrusted data which can quickly fill a disk in some
|
||||||
# Only enable debug logging while actively investigating an issue.
|
# scenarios. Debug and trace logging are also CPU intensive and will decrease performance overall.
|
||||||
|
# Only enable debug or trace logging while actively investigating an issue.
|
||||||
level: info
|
level: info
|
||||||
# json or text formats currently available. Default is text
|
# json or text formats currently available. Default is text.
|
||||||
format: text
|
format: text
|
||||||
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
|
# Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false.
|
||||||
#disable_timestamp: true
|
#disable_timestamp: true
|
||||||
# timestamp format is specified in Go time format, see:
|
# Timestamps use RFC3339Nano ("2006-01-02T15:04:05.999999999Z07:00") and are not configurable.
|
||||||
# https://golang.org/pkg/time/#pkg-constants
|
|
||||||
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
|
||||||
# default when `format: text`:
|
|
||||||
# when TTY attached: seconds since beginning of execution
|
|
||||||
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
|
||||||
# As an example, to log as RFC3339 with millisecond precision, set to:
|
|
||||||
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
|
|
||||||
|
|
||||||
#stats:
|
#stats:
|
||||||
#type: graphite
|
#type: graphite
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/service"
|
"github.com/slackhq/nebula/service"
|
||||||
)
|
)
|
||||||
@@ -64,8 +64,7 @@ pki:
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := logrus.New()
|
logger := logging.NewLogger(os.Stdout)
|
||||||
logger.Out = os.Stdout
|
|
||||||
|
|
||||||
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
65
firewall.go
65
firewall.go
@@ -1,11 +1,13 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -16,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -67,7 +68,7 @@ type Firewall struct {
|
|||||||
incomingMetrics firewallMetrics
|
incomingMetrics firewallMetrics
|
||||||
outgoingMetrics firewallMetrics
|
outgoingMetrics firewallMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type firewallMetrics struct {
|
type firewallMetrics struct {
|
||||||
@@ -131,7 +132,7 @@ type firewallLocalCIDR struct {
|
|||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
// The certificate provided should be the highest version loaded in memory.
|
// The certificate provided should be the highest version loaded in memory.
|
||||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||||
//TODO: error on 0 duration
|
//TODO: error on 0 duration
|
||||||
var tmin, tmax time.Duration
|
var tmin, tmax time.Duration
|
||||||
|
|
||||||
@@ -191,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
func NewFirewallFromConfig(l *slog.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||||
certificate := cs.getCertificate(cert.Version2)
|
certificate := cs.getCertificate(cert.Version2)
|
||||||
if certificate == nil {
|
if certificate == nil {
|
||||||
certificate = cs.getCertificate(cert.Version1)
|
certificate = cs.getCertificate(cert.Version1)
|
||||||
@@ -219,7 +220,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
case "drop":
|
case "drop":
|
||||||
fw.InSendReject = false
|
fw.InSendReject = false
|
||||||
default:
|
default:
|
||||||
l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`")
|
l.Warn("invalid firewall.inbound_action, defaulting to `drop`", "action", inboundAction)
|
||||||
fw.InSendReject = false
|
fw.InSendReject = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +231,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
case "drop":
|
case "drop":
|
||||||
fw.OutSendReject = false
|
fw.OutSendReject = false
|
||||||
default:
|
default:
|
||||||
l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`")
|
l.Warn("invalid firewall.outbound_action, defaulting to `drop`", "action", outboundAction)
|
||||||
fw.OutSendReject = false
|
fw.OutSendReject = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||||
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
|
//ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided
|
||||||
if startPort != firewall.PortAny {
|
if startPort != firewall.PortAny {
|
||||||
f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule")
|
f.l.Warn("ignoring port specification for ICMP firewall rule", "startPort", startPort)
|
||||||
}
|
}
|
||||||
startPort = firewall.PortAny
|
startPort = firewall.PortAny
|
||||||
endPort = firewall.PortAny
|
endPort = firewall.PortAny
|
||||||
@@ -290,8 +291,9 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
if !incoming {
|
if !incoming {
|
||||||
direction = "outgoing"
|
direction = "outgoing"
|
||||||
}
|
}
|
||||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
f.l.Info("Firewall rule added",
|
||||||
Info("Firewall rule added")
|
"firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha},
|
||||||
|
)
|
||||||
|
|
||||||
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
||||||
}
|
}
|
||||||
@@ -314,7 +316,7 @@ func (f *Firewall) GetRuleHashes() string {
|
|||||||
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
func AddFirewallRulesFromConfig(l *slog.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||||
var table string
|
var table string
|
||||||
if inbound {
|
if inbound {
|
||||||
table = "firewall.inbound"
|
table = "firewall.inbound"
|
||||||
@@ -372,7 +374,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
startPort = firewall.PortAny
|
startPort = firewall.PortAny
|
||||||
endPort = firewall.PortAny
|
endPort = firewall.PortAny
|
||||||
if sPort != "" {
|
if sPort != "" {
|
||||||
l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule")
|
l.Warn("ignoring port specification for ICMP firewall rule", "port", sPort)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||||
@@ -396,7 +398,11 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
}
|
}
|
||||||
|
|
||||||
if warning := r.sanity(); warning != nil {
|
if warning := r.sanity(); warning != nil {
|
||||||
l.Warnf("%s rule #%v; %s", table, i, warning)
|
l.Warn("firewall rule sanity check",
|
||||||
|
"table", table,
|
||||||
|
"rule", i,
|
||||||
|
"warning", warning,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
||||||
@@ -528,26 +534,26 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
|
|
||||||
// We now know which firewall table to check against
|
// We now know which firewall table to check against
|
||||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
h.logger(f.l).
|
h.logger(f.l).Debug("dropping old conntrack entry, does not match new ruleset",
|
||||||
WithField("fwPacket", fp).
|
"fwPacket", fp,
|
||||||
WithField("incoming", c.incoming).
|
"incoming", c.incoming,
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
"rulesVersion", f.rulesVersion,
|
||||||
WithField("oldRulesVersion", c.rulesVersion).
|
"oldRulesVersion", c.rulesVersion,
|
||||||
Debugln("dropping old conntrack entry, does not match new ruleset")
|
)
|
||||||
}
|
}
|
||||||
delete(conntrack.Conns, fp)
|
delete(conntrack.Conns, fp)
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
h.logger(f.l).
|
h.logger(f.l).Debug("keeping old conntrack entry, does match new ruleset",
|
||||||
WithField("fwPacket", fp).
|
"fwPacket", fp,
|
||||||
WithField("incoming", c.incoming).
|
"incoming", c.incoming,
|
||||||
WithField("rulesVersion", f.rulesVersion).
|
"rulesVersion", f.rulesVersion,
|
||||||
WithField("oldRulesVersion", c.rulesVersion).
|
"oldRulesVersion", c.rulesVersion,
|
||||||
Debugln("keeping old conntrack entry, does match new ruleset")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.rulesVersion = f.rulesVersion
|
c.rulesVersion = f.rulesVersion
|
||||||
@@ -935,7 +941,7 @@ type rule struct {
|
|||||||
CASha string
|
CASha string
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
func convertRule(l *slog.Logger, p any, table string, i int) (rule, error) {
|
||||||
r := rule{}
|
r := rule{}
|
||||||
|
|
||||||
m, ok := p.(map[string]any)
|
m, ok := p.(map[string]any)
|
||||||
@@ -966,7 +972,10 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
l.Warn("group was an array with a single value, converting to simple value",
|
||||||
|
"table", table,
|
||||||
|
"rule", i,
|
||||||
|
)
|
||||||
m["group"] = v[0]
|
m["group"] = v[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ package firewall
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log/slog"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
@@ -16,15 +15,17 @@ type ConntrackCacheTicker struct {
|
|||||||
cacheV uint64
|
cacheV uint64
|
||||||
cacheTick atomic.Uint64
|
cacheTick atomic.Uint64
|
||||||
|
|
||||||
|
l *slog.Logger
|
||||||
cache ConntrackCache
|
cache ConntrackCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConntrackCacheTicker(ctx context.Context, d time.Duration) *ConntrackCacheTicker {
|
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
|
||||||
if d == 0 {
|
if d == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ConntrackCacheTicker{
|
c := &ConntrackCacheTicker{
|
||||||
|
l: l,
|
||||||
cache: ConntrackCache{},
|
cache: ConntrackCache{},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,15 +49,15 @@ func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
|
|||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
// the map. If it has moved, we reset the map.
|
// the map. If it has moved, we reset the map.
|
||||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||||
c.cacheV = tick
|
c.cacheV = tick
|
||||||
if ll := len(c.cache); ll > 0 {
|
if ll := len(c.cache); ll > 0 {
|
||||||
if l.Level == logrus.DebugLevel {
|
if c.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
c.l.Debug("resetting conntrack cache", "len", ll)
|
||||||
}
|
}
|
||||||
c.cache = make(ConntrackCache, ll)
|
c.cache = make(ConntrackCache, ll)
|
||||||
}
|
}
|
||||||
|
|||||||
69
firewall/cache_test.go
Normal file
69
firewall/cache_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The tests below pin the log format produced by ConntrackCacheTicker.Get
|
||||||
|
// so changes cannot silently break what operators are grepping for. The
|
||||||
|
// ticker's internal state (cache + cacheTick) is poked directly to avoid
|
||||||
|
// racing a goroutine-driven tick in tests.
|
||||||
|
|
||||||
|
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
|
||||||
|
t.Helper()
|
||||||
|
c := &ConntrackCacheTicker{
|
||||||
|
l: l,
|
||||||
|
cache: make(ConntrackCache, cacheLen),
|
||||||
|
}
|
||||||
|
for i := 0; i < cacheLen; i++ {
|
||||||
|
c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{}
|
||||||
|
}
|
||||||
|
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 3)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 2)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 5)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||||
|
|
||||||
|
c := newFixedTicker(t, l, 0)
|
||||||
|
c.Get()
|
||||||
|
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
}
|
||||||
100
firewall_test.go
100
firewall_test.go
@@ -3,13 +3,13 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -58,9 +58,8 @@ func TestNewFirewall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
c := &dummyCert{}
|
c := &dummyCert{}
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
@@ -177,9 +176,8 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
@@ -254,9 +252,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropV6(t *testing.T) {
|
func TestFirewall_DropV6(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
@@ -485,9 +482,8 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -544,9 +540,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -633,9 +628,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3V6(t *testing.T) {
|
func TestFirewall_Drop3V6(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
@@ -671,9 +665,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -736,9 +729,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
@@ -880,9 +872,8 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||||
|
|
||||||
@@ -1045,25 +1036,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||||
|
|
||||||
// Test both port and code
|
// Test both port and code
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||||
|
|
||||||
// Test missing host, group, cidr, ca_name and ca_sha
|
// Test missing host, group, cidr, ca_name and ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||||
|
|
||||||
// Test code/port error
|
// Test code/port error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||||
@@ -1073,25 +1064,25 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||||
|
|
||||||
// Test proto error
|
// Test proto error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||||
|
|
||||||
// Test cidr parse error
|
// Test cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test local_cidr parse error
|
// Test local_cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test both group and groups
|
// Test both group and groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||||
@@ -1100,35 +1091,35 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// Test adding tcp rule
|
// Test adding tcp rule
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(test.NewLogger())
|
||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule no port
|
// Test adding icmp rule no port
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
@@ -1136,14 +1127,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding rule with cidr
|
// Test adding rule with cidr
|
||||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with local_cidr
|
// Test adding rule with local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
@@ -1151,82 +1142,82 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding rule with cidr ipv6
|
// Test adding rule with cidr ipv6
|
||||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any cidr
|
// Test adding rule with any cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with junk cidr
|
// Test adding rule with junk cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||||
|
|
||||||
// Test adding rule with local_cidr ipv6
|
// Test adding rule with local_cidr ipv6
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any local_cidr
|
// Test adding rule with any local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with junk local_cidr
|
// Test adding rule with junk local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(test.NewLogger())
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
mf.nextCallReturn = errors.New("test error")
|
mf.nextCallReturn = errors.New("test error")
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
@@ -1234,9 +1225,8 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRule(t *testing.T) {
|
func TestFirewall_convertRule(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
// Ensure group array of 1 is converted and a warning is printed
|
// Ensure group array of 1 is converted and a warning is printed
|
||||||
c := map[string]any{
|
c := map[string]any{
|
||||||
@@ -1244,7 +1234,9 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
r, err := convertRule(l, c, "test", 1)
|
||||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
assert.Contains(t, ob.String(), "group was an array with a single value, converting to simple value")
|
||||||
|
assert.Contains(t, ob.String(), "table=test")
|
||||||
|
assert.Contains(t, ob.String(), "rule=1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
assert.Equal(t, []string{"group1"}, r.Groups)
|
||||||
|
|
||||||
@@ -1270,9 +1262,8 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRuleSanity(t *testing.T) {
|
func TestFirewall_convertRuleSanity(t *testing.T) {
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
noWarningPlease := []map[string]any{
|
noWarningPlease := []map[string]any{
|
||||||
{"group": "group1"},
|
{"group": "group1"},
|
||||||
@@ -1386,7 +1377,7 @@ type testsetup struct {
|
|||||||
fw *Firewall
|
fw *Firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
func newSetup(t *testing.T, l *slog.Logger, myPrefixes ...netip.Prefix) testsetup {
|
||||||
c := dummyCert{
|
c := dummyCert{
|
||||||
name: "me",
|
name: "me",
|
||||||
networks: myPrefixes,
|
networks: myPrefixes,
|
||||||
@@ -1397,7 +1388,7 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
|
|||||||
return newSetupFromCert(t, l, c)
|
return newSetupFromCert(t, l, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
func newSetupFromCert(t *testing.T, l *slog.Logger, c dummyCert) testsetup {
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
for _, prefix := range c.Networks() {
|
for _, prefix := range c.Networks() {
|
||||||
myVpnNetworksTable.Insert(prefix)
|
myVpnNetworksTable.Insert(prefix)
|
||||||
@@ -1414,9 +1405,8 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|||||||
|
|
||||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l := test.NewLoggerWithOutput(ob)
|
||||||
|
|
||||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -18,7 +18,6 @@ require (
|
|||||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||||
github.com/sirupsen/logrus v1.9.4
|
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -133,8 +133,6 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
|
|||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
|
||||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
||||||
|
|||||||
541
handshake_ix.go
541
handshake_ix.go
@@ -2,11 +2,12 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
@@ -18,8 +19,11 @@ import (
|
|||||||
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||||
err := f.handshakeManager.allocateIndex(hh)
|
err := f.handshakeManager.allocateIndex(hh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to generate index",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
"error", err,
|
||||||
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,28 +43,32 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
crt := cs.getCertificate(v)
|
crt := cs.getCertificate(v)
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
WithField("certVersion", v).
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
"certVersion", v,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
crtHs := cs.getHandshakeBytes(v)
|
crtHs := cs.getHandshakeBytes(v)
|
||||||
if crtHs == nil {
|
if crtHs == nil {
|
||||||
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
WithField("certVersion", v).
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
"certVersion", v,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to create connection state",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"error", err,
|
||||||
WithField("certVersion", v).
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
Error("Failed to create connection state")
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
"certVersion", v,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
hh.hostinfo.ConnectionState = ci
|
hh.hostinfo.ConnectionState = ci
|
||||||
@@ -76,9 +84,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to marshal handshake message",
|
||||||
WithField("certVersion", v).
|
"error", err,
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
|
"certVersion", v,
|
||||||
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,8 +97,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.Error("Failed to call noise.WriteMessage",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
"error", err,
|
||||||
|
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||||
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,18 +118,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
crt := cs.GetDefaultCertificate()
|
crt := cs.GetDefaultCertificate()
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
f.l.WithField("from", via).
|
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
"from", via,
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
"certVersion", cs.initiatingVersion,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Error("Failed to create connection state",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Error("Failed to create connection state")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,26 +141,32 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Error("Failed to call noise.ReadMessage",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Error("Failed to call noise.ReadMessage")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Error("Failed unmarshal handshake message",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Error("Failed unmarshal handshake message")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("Handshake did not contain a certificate",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"error", err,
|
||||||
Info("Handshake did not contain a certificate")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,23 +177,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
attrs := []slog.Attr{
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
slog.Any("error", err),
|
||||||
WithField("certVpnNetworks", rc.Networks()).
|
slog.Any("from", via),
|
||||||
WithField("certFingerprint", fp)
|
slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
|
||||||
|
slog.Any("certVpnNetworks", rc.Networks()),
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
slog.String("certFingerprint", fp),
|
||||||
e = e.WithField("cert", rc)
|
}
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
attrs = append(attrs, slog.Any("cert", rc))
|
||||||
}
|
}
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||||
|
// callers grow conditionally, which has no pair-form equivalent.
|
||||||
|
//nolint:sloglint
|
||||||
|
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||||
f.l.WithField("from", via).
|
f.l.Info("public key mismatch between certificate and handshake",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"from", via,
|
||||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"cert", remoteCert,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,12 +208,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if myCertOtherVersion == nil {
|
if myCertOtherVersion == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithError(err).WithFields(m{
|
f.l.Debug("Might be unable to handshake with host due to missing certificate version",
|
||||||
"from": via,
|
"error", err,
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
"from", via,
|
||||||
"cert": remoteCert,
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
"cert", remoteCert,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Record the certificate we are actually using
|
// Record the certificate we are actually using
|
||||||
@@ -192,10 +223,12 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("No networks in certificate",
|
||||||
WithField("cert", remoteCert).
|
"error", err,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"from", via,
|
||||||
Info("No networks in certificate")
|
"cert", remoteCert,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,12 +242,15 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
for i, network := range vpnNetworks {
|
for i, network := range vpnNetworks {
|
||||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
|
f.l.Error("Refusing to handshake with myself",
|
||||||
WithField("certName", certName).
|
"vpnNetworks", vpnNetworks,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrs[i] = network.Addr()
|
vpnAddrs[i] = network.Addr()
|
||||||
@@ -226,20 +262,28 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
// We only want to apply the remote allow list for direct tunnels here
|
// We only want to apply the remote allow list for direct tunnels here
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||||
|
"vpnAddrs", vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
myIndex, err := generateIndex(f.l)
|
myIndex, err := generateIndex(f.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to generate index",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
"certVersion", certVersion,
|
||||||
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,18 +301,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgRxL := f.l.WithFields(m{
|
msgRxL := f.l.With(
|
||||||
"vpnAddrs": vpnAddrs,
|
"vpnAddrs", vpnAddrs,
|
||||||
"from": via,
|
"from", via,
|
||||||
"certName": certName,
|
"certName", certName,
|
||||||
"certVersion": certVersion,
|
"certVersion", certVersion,
|
||||||
"fingerprint": fingerprint,
|
"fingerprint", fingerprint,
|
||||||
"issuer": issuer,
|
"issuer", issuer,
|
||||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
"responderIndex": hs.Details.ResponderIndex,
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
"remoteIndex": h.RemoteIndex,
|
"remoteIndex", h.RemoteIndex,
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
})
|
)
|
||||||
|
|
||||||
if anyVpnAddrsInCommon {
|
if anyVpnAddrsInCommon {
|
||||||
msgRxL.Info("Handshake message received")
|
msgRxL.Info("Handshake message received")
|
||||||
@@ -280,8 +324,9 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
"myCertVersion", ci.myCert.Version(),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,32 +336,43 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to marshal handshake message",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
"certVersion", certVersion,
|
||||||
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to call noise.WriteMessage",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
"certVersion", certVersion,
|
||||||
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Noise did not arrive at a key",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
"fingerprint", fingerprint,
|
||||||
|
"issuer", issuer,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -358,13 +414,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
err := f.outside.WriteTo(msg, via.UdpAddr)
|
err := f.outside.WriteTo(msg, via.UdpAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to send handshake message",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
"vpnAddrs", existing.vpnAddrs,
|
||||||
WithError(err).Error("Failed to send handshake message")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cached", true,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
f.l.Info("Handshake message sent",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
"vpnAddrs", existing.vpnAddrs,
|
||||||
Info("Handshake message sent")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cached", true,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
@@ -374,50 +437,67 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
}
|
}
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
f.l.Info("Handshake message sent",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
"vpnAddrs", existing.vpnAddrs,
|
||||||
Info("Handshake message sent")
|
"relay", via.relayHI.vpnAddrs[0],
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cached", true,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case ErrExistingHostInfo:
|
case ErrExistingHostInfo:
|
||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Info("Handshake too old",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
"certName", certName,
|
||||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
"certVersion", certVersion,
|
||||||
WithField("fingerprint", fingerprint).
|
"oldHandshakeTime", existing.lastHandshakeTime,
|
||||||
WithField("issuer", issuer).
|
"newHandshakeTime", hostinfo.lastHandshakeTime,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
Info("Handshake too old")
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
return
|
return
|
||||||
case ErrLocalIndexCollision:
|
case ErrLocalIndexCollision:
|
||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to add HostInfo due to localIndex collision",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
Error("Failed to add HostInfo due to localIndex collision")
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"localIndex", hostinfo.localIndexId,
|
||||||
|
"collision", existing.vpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to add HostInfo to HostMap",
|
||||||
WithField("certName", certName).
|
"error", err,
|
||||||
WithField("certVersion", certVersion).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"from", via,
|
||||||
WithField("issuer", issuer).
|
"certName", certName,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"certVersion", certVersion,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"fingerprint", fingerprint,
|
||||||
Error("Failed to add HostInfo to HostMap")
|
"issuer", issuer,
|
||||||
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -426,15 +506,20 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
err = f.outside.WriteTo(msg, via.UdpAddr)
|
err = f.outside.WriteTo(msg, via.UdpAddr)
|
||||||
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
log := f.l.With(
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
"issuer", issuer,
|
||||||
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Failed to send handshake")
|
log.Error("Failed to send handshake", "error", err)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Handshake message sent")
|
log.Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
@@ -448,14 +533,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
|
|||||||
// it's correctly marked as working.
|
// it's correctly marked as working.
|
||||||
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
f.l.Info("Handshake message sent",
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"relay", via.relayHI.vpnAddrs[0],
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
Info("Handshake message sent")
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||||
@@ -483,7 +572,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
)
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -491,18 +585,24 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed to call noise.ReadMessage",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
"error", err,
|
||||||
Error("Failed to call noise.ReadMessage")
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"header", h,
|
||||||
|
)
|
||||||
|
|
||||||
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
|
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
|
||||||
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
||||||
// near future
|
// near future
|
||||||
return false
|
return false
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Noise did not arrive at a key",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Error("Noise did not arrive at a key")
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
|
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
|
||||||
// the handshake state machine. Tear it down
|
// the handshake state machine. Tear it down
|
||||||
@@ -512,8 +612,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.Error("Failed unmarshal handshake message",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
"error", err,
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"from", via,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||||
return true
|
return true
|
||||||
@@ -521,10 +625,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("Handshake did not contain a certificate",
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
"error", err,
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"from", via,
|
||||||
Info("Handshake did not contain a certificate")
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,32 +641,41 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
attrs := []slog.Attr{
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
slog.Any("error", err),
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
slog.Any("from", via),
|
||||||
WithField("certFingerprint", fp).
|
slog.Any("vpnAddrs", hostinfo.vpnAddrs),
|
||||||
WithField("certVpnNetworks", rc.Networks())
|
slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
|
||||||
|
slog.String("certFingerprint", fp),
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
slog.Any("certVpnNetworks", rc.Networks()),
|
||||||
e = e.WithField("cert", rc)
|
}
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
attrs = append(attrs, slog.Any("cert", rc))
|
||||||
}
|
}
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||||
|
// callers grow conditionally, which has no pair-form equivalent.
|
||||||
|
//nolint:sloglint
|
||||||
|
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||||
f.l.WithField("from", via).
|
f.l.Info("public key mismatch between certificate and handshake",
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"from", via,
|
||||||
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"cert", remoteCert,
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.Info("No networks in certificate",
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
"error", err,
|
||||||
WithField("cert", remoteCert).
|
"from", via,
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
Info("No networks in certificate")
|
"cert", remoteCert,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -601,12 +716,14 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !correctHostResponded {
|
if !correctHostResponded {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.Info("Incorrect host responded to handshake",
|
||||||
WithField("from", via).
|
"intendedVpnAddrs", hostinfo.vpnAddrs,
|
||||||
WithField("certName", certName).
|
"haveVpnNetworks", vpnNetworks,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"certName", certName,
|
||||||
Info("Incorrect host responded to handshake")
|
"certVersion", certVersion,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
|
|
||||||
// Release our old handshake from pending, it should not continue
|
// Release our old handshake from pending, it should not continue
|
||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
@@ -618,10 +735,11 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
newHH.hostinfo.remotes.BlockRemote(via)
|
newHH.hostinfo.remotes.BlockRemote(via)
|
||||||
|
|
||||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
f.l.Info("Blocked addresses for handshakes",
|
||||||
WithField("vpnNetworks", vpnNetworks).
|
"blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
|
||||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
"vpnNetworks", vpnNetworks,
|
||||||
Info("Blocked addresses for handshakes")
|
"remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
|
||||||
|
)
|
||||||
|
|
||||||
// Swap the packet store to benefit the original intended recipient
|
// Swap the packet store to benefit the original intended recipient
|
||||||
newHH.packetStore = hh.packetStore
|
newHH.packetStore = hh.packetStore
|
||||||
@@ -639,15 +757,20 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
msgRxL := f.l.With(
|
||||||
WithField("certName", certName).
|
"vpnAddrs", vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"from", via,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName", certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion", certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint", fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
"issuer", issuer,
|
||||||
WithField("durationNs", duration).
|
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||||
WithField("sentCachedPackets", len(hh.packetStore))
|
"responderIndex", hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||||
|
"durationNs", duration,
|
||||||
|
"sentCachedPackets", len(hh.packetStore),
|
||||||
|
)
|
||||||
if anyVpnAddrsInCommon {
|
if anyVpnAddrsInCommon {
|
||||||
msgRxL.Info("Handshake message received")
|
msgRxL.Info("Handshake message received")
|
||||||
} else {
|
} else {
|
||||||
@@ -663,8 +786,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
|
|||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
hostinfo.logger(f.l).Debug("Sending stored packets",
|
||||||
|
"count", len(hh.packetStore),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(hh.packetStore) > 0 {
|
if len(hh.packetStore) > 0 {
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
@@ -59,7 +59,7 @@ type HandshakeManager struct {
|
|||||||
metricInitiated metrics.Counter
|
metricInitiated metrics.Counter
|
||||||
metricTimedOut metrics.Counter
|
metricTimedOut metrics.Counter
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
// can be used to trigger outbound handshake for the given vpnIp
|
// can be used to trigger outbound handshake for the given vpnIp
|
||||||
trigger chan netip.Addr
|
trigger chan netip.Addr
|
||||||
@@ -78,32 +78,32 @@ type HandshakeHostInfo struct {
|
|||||||
hostinfo *HostInfo
|
hostinfo *HostInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||||
if len(hh.packetStore) < 100 {
|
if len(hh.packetStore) < 100 {
|
||||||
tempPacket := make([]byte, len(packet))
|
tempPacket := make([]byte, len(packet))
|
||||||
copy(tempPacket, packet)
|
copy(tempPacket, packet)
|
||||||
|
|
||||||
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
|
hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hh.hostinfo.logger(l).
|
hh.hostinfo.logger(l).Debug("Packet store",
|
||||||
WithField("length", len(hh.packetStore)).
|
"length", len(hh.packetStore),
|
||||||
WithField("stored", true).
|
"stored", true,
|
||||||
Debugf("Packet store")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
m.dropped.Inc(1)
|
m.dropped.Inc(1)
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hh.hostinfo.logger(l).
|
hh.hostinfo.logger(l).Debug("Packet store",
|
||||||
WithField("length", len(hh.packetStore)).
|
"length", len(hh.packetStore),
|
||||||
WithField("stored", false).
|
"stored", false,
|
||||||
Debugf("Packet store")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
|
func NewHandshakeManager(l *slog.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
|
||||||
return &HandshakeManager{
|
return &HandshakeManager{
|
||||||
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
|
vpnIps: map[netip.Addr]*HandshakeHostInfo{},
|
||||||
indexes: map[uint32]*HandshakeHostInfo{},
|
indexes: map[uint32]*HandshakeHostInfo{},
|
||||||
@@ -140,7 +140,7 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
|
|||||||
// First remote allow list check before we know the vpnIp
|
// First remote allow list check before we know the vpnIp
|
||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
|
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
|
||||||
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
hm.l.Debug("lighthouse.remote_allow_list denied incoming handshake", "from", via)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -183,12 +183,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo := hh.hostinfo
|
hostinfo := hh.hostinfo
|
||||||
// If we are out of time, clean up
|
// If we are out of time, clean up
|
||||||
if hh.counter >= hm.config.retries {
|
if hh.counter >= hm.config.retries {
|
||||||
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
|
hh.hostinfo.logger(hm.l).Info("Handshake timed out",
|
||||||
WithField("initiatorIndex", hh.hostinfo.localIndexId).
|
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
|
||||||
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
|
"initiatorIndex", hh.hostinfo.localIndexId,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"remoteIndex", hh.hostinfo.remoteIndexId,
|
||||||
WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
Info("Handshake timed out")
|
"durationNs", time.Since(hh.startTime).Nanoseconds(),
|
||||||
|
)
|
||||||
hm.metricTimedOut.Inc(1)
|
hm.metricTimedOut.Inc(1)
|
||||||
hm.DeleteHostInfo(hostinfo)
|
hm.DeleteHostInfo(hostinfo)
|
||||||
return
|
return
|
||||||
@@ -241,10 +242,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).WithField("udpAddr", addr).
|
hostinfo.logger(hm.l).Error("Failed to send handshake message",
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
"udpAddr", addr,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"initiatorIndex", hostinfo.localIndexId,
|
||||||
WithError(err).Error("Failed to send handshake message")
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
sentTo = append(sentTo, addr)
|
sentTo = append(sentTo, addr)
|
||||||
@@ -254,19 +257,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
|
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
|
||||||
// so only log when the list of remotes has changed
|
// so only log when the list of remotes has changed
|
||||||
if remotesHaveChanged {
|
if remotesHaveChanged {
|
||||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
hostinfo.logger(hm.l).Info("Handshake message sent",
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
"udpAddrs", sentTo,
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"initiatorIndex", hostinfo.localIndexId,
|
||||||
Info("Handshake message sent")
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
} else if hm.l.Level >= logrus.DebugLevel {
|
)
|
||||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
} else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
hostinfo.logger(hm.l).Debug("Handshake message sent",
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"udpAddrs", sentTo,
|
||||||
Debug("Handshake message sent")
|
"initiatorIndex", hostinfo.localIndexId,
|
||||||
|
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
|
if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
|
||||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
|
||||||
// Send a RelayRequest to all known Relay IP's
|
// Send a RelayRequest to all known Relay IP's
|
||||||
for _, relay := range hostinfo.remotes.relays {
|
for _, relay := range hostinfo.remotes.relays {
|
||||||
// Don't relay through the host I'm trying to connect to
|
// Don't relay through the host I'm trying to connect to
|
||||||
@@ -281,7 +286,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
|
|
||||||
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
|
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
|
||||||
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
|
hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String())
|
||||||
hm.f.Handshake(relay)
|
hm.f.Handshake(relay)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -292,7 +297,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
if relayHostInfo.remote.IsValid() {
|
if relayHostInfo.remote.IsValid() {
|
||||||
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := NebulaControl{
|
m := NebulaControl{
|
||||||
@@ -326,17 +331,15 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
|
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||||
WithError(err).
|
|
||||||
Error("Failed to marshal Control message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": hm.f.myVpnAddrs[0],
|
"relayFrom", hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo", vpnIp,
|
||||||
"initiatorRelayIndex": idx,
|
"initiatorRelayIndex", idx,
|
||||||
"relay": relay}).
|
"relay", relay,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -344,14 +347,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
|
|
||||||
switch existingRelay.State {
|
switch existingRelay.State {
|
||||||
case Established:
|
case Established:
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
|
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
|
||||||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
||||||
case Disestablished:
|
case Disestablished:
|
||||||
// Mark this relay as 'requested'
|
// Mark this relay as 'requested'
|
||||||
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
||||||
fallthrough
|
fallthrough
|
||||||
case Requested:
|
case Requested:
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String())
|
||||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
// Re-send the CreateRelay request, in case the previous one was lost.
|
||||||
m := NebulaControl{
|
m := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
@@ -383,28 +386,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
}
|
}
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err)
|
||||||
WithError(err).
|
|
||||||
Error("Failed to marshal Control message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": hm.f.myVpnAddrs[0],
|
"relayFrom", hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo", vpnIp,
|
||||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
"initiatorRelayIndex", existingRelay.LocalIndex,
|
||||||
"relay": relay}).
|
"relay", relay,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
case PeerRequested:
|
case PeerRequested:
|
||||||
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
|
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Error("Relay unexpected state",
|
||||||
WithField("vpnIp", vpnIp).
|
"vpnIp", vpnIp,
|
||||||
WithField("state", existingRelay.State).
|
"state", existingRelay.State,
|
||||||
WithField("relay", relay).
|
"relay", relay,
|
||||||
Errorf("Relay unexpected state")
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -549,9 +550,10 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
|||||||
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
|
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
"remoteIndex", hostinfo.remoteIndexId,
|
||||||
Info("New host shadows existing host remoteIndex")
|
"collision", existingRemoteIndex.vpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
||||||
@@ -571,9 +573,10 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
|||||||
if found && existingRemoteIndex != nil {
|
if found && existingRemoteIndex != nil {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex",
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
"remoteIndex", hostinfo.remoteIndexId,
|
||||||
Info("New host shadows existing host remoteIndex")
|
"collision", existingRemoteIndex.vpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
|
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
|
||||||
@@ -629,10 +632,11 @@ func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
|||||||
hm.indexes = map[uint32]*HandshakeHostInfo{}
|
hm.indexes = map[uint32]*HandshakeHostInfo{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
hm.l.Debug("Pending hostmap hostInfo deleted",
|
||||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
||||||
Debug("Pending hostmap hostInfo deleted")
|
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -700,7 +704,7 @@ func (hm *HandshakeManager) EmitStats() {
|
|||||||
|
|
||||||
// Utility functions below
|
// Utility functions below
|
||||||
|
|
||||||
func generateIndex(l *logrus.Logger) (uint32, error) {
|
func generateIndex(l *slog.Logger) (uint32, error) {
|
||||||
b := make([]byte, 4)
|
b := make([]byte, 4)
|
||||||
|
|
||||||
// Let zero mean we don't know the ID, so don't generate zero
|
// Let zero mean we don't know the ID, so don't generate zero
|
||||||
@@ -708,16 +712,15 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
|
|||||||
for index == 0 {
|
for index == 0 {
|
||||||
_, err := rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorln(err)
|
l.Error("Failed to generate index", "error", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
index = binary.BigEndian.Uint32(b)
|
index = binary.BigEndian.Uint32(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("index", index).
|
l.Debug("Generated index", "index", index)
|
||||||
Debug("Generated index")
|
|
||||||
}
|
}
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|||||||
76
hostmap.go
76
hostmap.go
@@ -1,9 +1,11 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -13,10 +15,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||||
@@ -60,7 +62,7 @@ type HostMap struct {
|
|||||||
RemoteIndexes map[uint32]*HostInfo
|
RemoteIndexes map[uint32]*HostInfo
|
||||||
Hosts map[netip.Addr]*HostInfo
|
Hosts map[netip.Addr]*HostInfo
|
||||||
preferredRanges atomic.Pointer[[]netip.Prefix]
|
preferredRanges atomic.Pointer[[]netip.Prefix]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
|
// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay
|
||||||
@@ -313,7 +315,7 @@ type cachedPacketMetrics struct {
|
|||||||
dropped metrics.Counter
|
dropped metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
func NewHostMapFromConfig(l *slog.Logger, c *config.C) *HostMap {
|
||||||
hm := newHostMap(l)
|
hm := newHostMap(l)
|
||||||
|
|
||||||
hm.reload(c, true)
|
hm.reload(c, true)
|
||||||
@@ -321,13 +323,12 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
|||||||
hm.reload(c, false)
|
hm.reload(c, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("preferredRanges", hm.GetPreferredRanges()).
|
l.Info("Main HostMap created", "preferredRanges", hm.GetPreferredRanges())
|
||||||
Info("Main HostMap created")
|
|
||||||
|
|
||||||
return hm
|
return hm
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostMap(l *logrus.Logger) *HostMap {
|
func newHostMap(l *slog.Logger) *HostMap {
|
||||||
return &HostMap{
|
return &HostMap{
|
||||||
Indexes: map[uint32]*HostInfo{},
|
Indexes: map[uint32]*HostInfo{},
|
||||||
Relays: map[uint32]*HostInfo{},
|
Relays: map[uint32]*HostInfo{},
|
||||||
@@ -346,7 +347,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
|||||||
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
preferredRange, err := netip.ParsePrefix(rawPreferredRange)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
|
hm.l.Warn("Failed to parse preferred ranges, ignoring",
|
||||||
|
"error", err,
|
||||||
|
"range", rawPreferredRanges,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,7 +359,10 @@ func (hm *HostMap) reload(c *config.C, initial bool) {
|
|||||||
|
|
||||||
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
||||||
if !initial {
|
if !initial {
|
||||||
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
|
hm.l.Info("preferred_ranges changed",
|
||||||
|
"oldPreferredRanges", *oldRanges,
|
||||||
|
"newPreferredRanges", preferredRanges,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -488,10 +495,11 @@ func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Ad
|
|||||||
hm.Indexes = map[uint32]*HostInfo{}
|
hm.Indexes = map[uint32]*HostInfo{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
|
hm.l.Debug("Hostmap hostInfo deleted",
|
||||||
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"hostMap", m{"mapTotalSize": len(hm.Hosts),
|
||||||
Debug("Hostmap hostInfo deleted")
|
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLastHostinfo {
|
if isLastHostinfo {
|
||||||
@@ -615,10 +623,11 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
|||||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
hm.l.Debug("Hostmap vpnIp added",
|
||||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
|
"hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
||||||
Debug("Hostmap vpnIp added")
|
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -784,18 +793,21 @@ func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certifica
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
// logger returns a derived slog.Logger with per-hostinfo fields pre-bound.
|
||||||
|
func (i *HostInfo) logger(l *slog.Logger) *slog.Logger {
|
||||||
if i == nil {
|
if i == nil {
|
||||||
return logrus.NewEntry(l)
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
li := l.WithField("vpnAddrs", i.vpnAddrs).
|
li := l.With(
|
||||||
WithField("localIndex", i.localIndexId).
|
"vpnAddrs", i.vpnAddrs,
|
||||||
WithField("remoteIndex", i.remoteIndexId)
|
"localIndex", i.localIndexId,
|
||||||
|
"remoteIndex", i.remoteIndexId,
|
||||||
|
)
|
||||||
|
|
||||||
if connState := i.ConnectionState; connState != nil {
|
if connState := i.ConnectionState; connState != nil {
|
||||||
if peerCert := connState.peerCert; peerCert != nil {
|
if peerCert := connState.peerCert; peerCert != nil {
|
||||||
li = li.WithField("certName", peerCert.Certificate.Name())
|
li = li.With("certName", peerCert.Certificate.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -804,14 +816,17 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
|||||||
|
|
||||||
// Utility functions
|
// Utility functions
|
||||||
|
|
||||||
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
func localAddrs(l *slog.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||||
//FIXME: This function is pretty garbage
|
//FIXME: This function is pretty garbage
|
||||||
var finalAddrs []netip.Addr
|
var finalAddrs []netip.Addr
|
||||||
ifaces, _ := net.Interfaces()
|
ifaces, _ := net.Interfaces()
|
||||||
for _, i := range ifaces {
|
for _, i := range ifaces {
|
||||||
allow := allowList.AllowName(i.Name)
|
allow := allowList.AllowName(i.Name)
|
||||||
if l.Level >= logrus.TraceLevel {
|
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
|
l.Log(context.Background(), logging.LevelTrace, "localAllowList.AllowName",
|
||||||
|
"interfaceName", i.Name,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -829,8 +844,8 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !addr.IsValid() {
|
if !addr.IsValid() {
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
|
l.Debug("addr was invalid", "localAddr", rawAddr)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -838,8 +853,11 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
|||||||
|
|
||||||
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
||||||
isAllowed := allowList.Allow(addr)
|
isAllowed := allowList.Allow(addr)
|
||||||
if l.Level >= logrus.TraceLevel {
|
if l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
|
l.Log(context.Background(), logging.LevelTrace, "localAllowList.Allow",
|
||||||
|
"localAddr", addr,
|
||||||
|
"allowed", isAllowed,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !isAllowed {
|
if !isAllowed {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
|
|
||||||
func TestHostMap_reload(t *testing.T) {
|
func TestHostMap_reload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(test.NewLogger())
|
||||||
|
|
||||||
hm := NewHostMapFromConfig(l, c)
|
hm := NewHostMapFromConfig(l, c)
|
||||||
|
|
||||||
|
|||||||
119
inside.go
119
inside.go
@@ -1,9 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
@@ -14,8 +15,11 @@ import (
|
|||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
f.l.Debug("Error while validating outbound packet",
|
||||||
|
"packet", packet,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -35,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
if immediatelyForwardToSelf {
|
if immediatelyForwardToSelf {
|
||||||
_, err := f.readers[q].Write(packet)
|
_, err := f.readers[q].Write(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to forward to tun")
|
f.l.Error("Failed to forward to tun", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||||
@@ -54,10 +58,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
|
||||||
WithField("fwPacket", fwPacket).
|
"vpnAddr", fwPacket.RemoteAddr,
|
||||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
"fwPacket", fwPacket,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -72,11 +77,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).
|
hostinfo.logger(f.l).Debug("dropping outbound packet",
|
||||||
WithField("fwPacket", fwPacket).
|
"fwPacket", fwPacket,
|
||||||
WithField("reason", dropReason).
|
"reason", dropReason,
|
||||||
Debugln("dropping outbound packet")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,7 +98,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
|||||||
|
|
||||||
_, err := f.readers[q].Write(out)
|
_, err := f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.Error("Failed to write to tun", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,11 +113,11 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(out) > iputil.MaxRejectPacketSize {
|
if len(out) > iputil.MaxRejectPacketSize {
|
||||||
if f.l.GetLevel() >= logrus.InfoLevel {
|
if f.l.Enabled(context.Background(), slog.LevelInfo) {
|
||||||
f.l.
|
f.l.Info("rejectOutside: packet too big, not sending",
|
||||||
WithField("packet", packet).
|
"packet", packet,
|
||||||
WithField("outPacket", out).
|
"outPacket", out,
|
||||||
Info("rejectOutside: packet too big, not sending")
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -184,10 +189,11 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
|
|||||||
// This would also need to interact with unsafe_route updates through reloading the config or
|
// This would also need to interact with unsafe_route updates through reloading the config or
|
||||||
// use of the use_system_route_table option
|
// use of the use_system_route_table option
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("destination", destinationAddr).
|
f.l.Debug("Calculated gateway for ECMP not available, attempting other gateways",
|
||||||
WithField("originalGateway", gatewayAddr).
|
"destination", destinationAddr,
|
||||||
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
|
"originalGateway", gatewayAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range gateways {
|
for i := range gateways {
|
||||||
@@ -213,17 +219,18 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
fp := &firewall.Packet{}
|
fp := &firewall.Packet{}
|
||||||
err := newPacket(p, false, fp)
|
err := newPacket(p, false, fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("fwPacket", fp).
|
f.l.Debug("dropping cached packet",
|
||||||
WithField("reason", dropReason).
|
"fwPacket", fp,
|
||||||
Debugln("dropping cached packet")
|
"reason", dropReason,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -239,9 +246,10 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
|
|||||||
})
|
})
|
||||||
|
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).
|
f.l.Debug("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes",
|
||||||
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
|
"vpnAddr", vpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -297,12 +305,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
if noiseutil.EncryptLockNeeded {
|
if noiseutil.EncryptLockNeeded {
|
||||||
via.ConnectionState.writeLock.Unlock()
|
via.ConnectionState.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
via.logger(f.l).
|
via.logger(f.l).Error("SendVia out buffer not large enough for relay",
|
||||||
WithField("outCap", cap(out)).
|
"outCap", cap(out),
|
||||||
WithField("payloadLen", len(ad)).
|
"payloadLen", len(ad),
|
||||||
WithField("headerLen", len(out)).
|
"headerLen", len(out),
|
||||||
WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()).
|
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
|
||||||
Error("SendVia out buffer not large enough for relay")
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -322,12 +330,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
via.ConnectionState.writeLock.Unlock()
|
via.ConnectionState.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
|
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = f.writers[0].WriteTo(out, via.remote)
|
err = f.writers[0].WriteTo(out, via.remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
|
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
|
||||||
}
|
}
|
||||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||||
}
|
}
|
||||||
@@ -366,8 +374,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
f.l.Debug("Lighthouse update triggered for punch due to rebind counter",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,24 +387,30 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
ci.writeLock.Unlock()
|
ci.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
|
||||||
WithField("udpAddr", remote).WithField("counter", c).
|
"error", err,
|
||||||
WithField("attemptedCounter", c).
|
"udpAddr", remote,
|
||||||
Error("Failed to encrypt outgoing packet")
|
"counter", c,
|
||||||
|
"attemptedCounter", c,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if remote.IsValid() {
|
if remote.IsValid() {
|
||||||
err = f.writers[q].WriteTo(out, remote)
|
err = f.writers[q].WriteTo(out, remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
"error", err,
|
||||||
|
"udpAddr", remote,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else if hostinfo.remote.IsValid() {
|
} else if hostinfo.remote.IsValid() {
|
||||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
"error", err,
|
||||||
|
"udpAddr", remote,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Try to send via a relay
|
// Try to send via a relay
|
||||||
@@ -402,7 +418,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.relayState.DeleteRelay(relayIP)
|
hostinfo.relayState.DeleteRelay(relayIP)
|
||||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
|
||||||
|
"relay", relayIP,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||||
|
|||||||
66
interface.go
66
interface.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -12,7 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -46,7 +47,7 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
@@ -100,7 +101,7 @@ type Interface struct {
|
|||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -223,13 +224,16 @@ func (f *Interface) activate() error {
|
|||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to get udp listen address")
|
f.l.Error("Failed to get udp listen address", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
|
f.l.Info("Nebula interface is active",
|
||||||
WithField("build", f.version).WithField("udpAddr", addr).
|
"interface", f.inside.Name(),
|
||||||
WithField("boringcrypto", boringEnabled()).
|
"networks", f.myVpnNetworks,
|
||||||
Info("Nebula interface is active")
|
"build", f.version,
|
||||||
|
"udpAddr", addr,
|
||||||
|
"boringcrypto", boringEnabled(),
|
||||||
|
)
|
||||||
|
|
||||||
if f.routines > 1 {
|
if f.routines > 1 {
|
||||||
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
||||||
@@ -305,7 +309,7 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
plaintext := make([]byte, udp.MTU)
|
plaintext := make([]byte, udp.MTU)
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
@@ -313,15 +317,15 @@ func (f *Interface) listenOut(i int) {
|
|||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil && !f.closed.Load() {
|
if err != nil && !f.closed.Load() {
|
||||||
f.l.WithError(err).Error("Error while reading inbound packet, closing")
|
f.l.Error("Error while reading inbound packet, closing", "error", err)
|
||||||
f.onFatal(err)
|
f.onFatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.Debugf("underlay reader %v is done", i)
|
f.l.Debug("underlay reader is done", "reader", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
@@ -330,22 +334,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !f.closed.Load() {
|
if !f.closed.Load() {
|
||||||
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
|
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
|
||||||
f.onFatal(err)
|
f.onFatal(err)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.Debugf("overlay reader %v is done", i)
|
f.l.Debug("overlay reader is done", "reader", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
@@ -365,7 +369,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
|
|||||||
if initial || c.HasChanged("pki.disconnect_invalid") {
|
if initial || c.HasChanged("pki.disconnect_invalid") {
|
||||||
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
|
f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
|
||||||
if !initial {
|
if !initial {
|
||||||
f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
|
f.l.Info("pki.disconnect_invalid changed", "value", f.disconnectInvalid.Load())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -379,7 +383,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
|
|
||||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
f.l.Error("Error while creating firewall during reload", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,10 +396,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||||
// safe and just reset conntrack in this case.
|
// safe and just reset conntrack in this case.
|
||||||
if fw.rulesVersion == 0 {
|
if fw.rulesVersion == 0 {
|
||||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
f.l.Warn("firewall rulesVersion has overflowed, resetting conntrack",
|
||||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
"firewallHashes", fw.GetRuleHashes(),
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
"rulesVersion", fw.rulesVersion,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
fw.Conntrack = conntrack
|
fw.Conntrack = conntrack
|
||||||
}
|
}
|
||||||
@@ -403,10 +408,11 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
f.firewall = fw
|
f.firewall = fw
|
||||||
|
|
||||||
oldFw.Destroy()
|
oldFw.Destroy()
|
||||||
f.l.WithField("firewallHashes", fw.GetRuleHashes()).
|
f.l.Info("New firewall has been installed",
|
||||||
WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
|
"firewallHashes", fw.GetRuleHashes(),
|
||||||
WithField("rulesVersion", fw.rulesVersion).
|
"oldFirewallHashes", oldFw.GetRuleHashes(),
|
||||||
Info("New firewall has been installed")
|
"rulesVersion", fw.rulesVersion,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadSendRecvError(c *config.C) {
|
func (f *Interface) reloadSendRecvError(c *config.C) {
|
||||||
@@ -428,8 +434,7 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
|
f.l.Info("Loaded send_recv_error config", "sendRecvError", f.sendRecvErrorConfig.String())
|
||||||
Info("Loaded send_recv_error config")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -452,8 +457,7 @@ func (f *Interface) reloadAcceptRecvError(c *config.C) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()).
|
f.l.Info("Loaded accept_recv_error config", "acceptRecvError", f.acceptRecvErrorConfig.String())
|
||||||
Info("Loaded accept_recv_error config")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -527,7 +531,7 @@ func (f *Interface) Close() error {
|
|||||||
for i, u := range f.writers {
|
for i, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket")
|
f.l.Error("Error while closing udp socket", "error", err, "writer", i)
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
228
lighthouse.go
228
lighthouse.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -15,10 +16,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
@@ -76,12 +77,12 @@ type LightHouse struct {
|
|||||||
|
|
||||||
metrics *MessageMetrics
|
metrics *MessageMetrics
|
||||||
metricHolepunchTx metrics.Counter
|
metricHolepunchTx metrics.Counter
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
|
||||||
// addrMap should be nil unless this is during a config reload
|
// addrMap should be nil unless this is during a config reload
|
||||||
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) {
|
||||||
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||||
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
nebulaPort := uint32(c.GetInt("listen.port", 0))
|
||||||
if amLighthouse && nebulaPort == 0 {
|
if amLighthouse && nebulaPort == 0 {
|
||||||
@@ -133,7 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
case *util.ContextualError:
|
case *util.ContextualError:
|
||||||
v.Log(l)
|
v.Log(l)
|
||||||
case error:
|
case error:
|
||||||
l.WithError(err).Error("failed to reload lighthouse")
|
l.Error("failed to reload lighthouse", "error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -205,8 +206,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
||||||
addr := addrs[0].Unmap()
|
addr := addrs[0].Unmap()
|
||||||
if lh.myVpnNetworksTable.Contains(addr) {
|
if lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
lh.l.Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range",
|
||||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
"addr", rawAddr,
|
||||||
|
"entry", i+1,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,7 +227,9 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
|
lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
|
||||||
|
|
||||||
if !initial {
|
if !initial {
|
||||||
lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
|
lh.l.Info("lighthouse.interval changed",
|
||||||
|
"interval", lh.interval.Load(),
|
||||||
|
)
|
||||||
|
|
||||||
if lh.updateCancel != nil {
|
if lh.updateCancel != nil {
|
||||||
// May not always have a running routine
|
// May not always have a running routine
|
||||||
@@ -336,9 +341,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
for _, v := range c.GetStringSlice("relay.relays", nil) {
|
||||||
configRIP, err := netip.ParseAddr(v)
|
configRIP, err := netip.ParseAddr(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
|
lh.l.Warn("Parse relay from config failed",
|
||||||
|
"relay", v,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
lh.l.WithField("relay", v).Info("Read relay from config")
|
lh.l.Info("Read relay from config", "relay", v)
|
||||||
relaysForMe = append(relaysForMe, configRIP)
|
relaysForMe = append(relaysForMe, configRIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -363,8 +371,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
lh.l.Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not",
|
||||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
"vpnAddr", addr,
|
||||||
|
"networks", lh.myVpnNetworks,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
out[i] = addr
|
out[i] = addr
|
||||||
}
|
}
|
||||||
@@ -435,8 +445,11 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
lh.l.Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work",
|
||||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
"vpnAddr", vpnAddr,
|
||||||
|
"networks", lh.myVpnNetworks,
|
||||||
|
"entry", i+1,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
vals, ok := v.([]any)
|
vals, ok := v.([]any)
|
||||||
@@ -537,12 +550,13 @@ func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
|||||||
lh.Lock()
|
lh.Lock()
|
||||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||||
if ok {
|
if ok {
|
||||||
|
debugEnabled := lh.l.Enabled(context.Background(), slog.LevelDebug)
|
||||||
for _, addr := range allVpnAddrs {
|
for _, addr := range allVpnAddrs {
|
||||||
srm := lh.addrMap[addr]
|
srm := lh.addrMap[addr]
|
||||||
if srm == rm {
|
if srm == rm {
|
||||||
delete(lh.addrMap, addr)
|
delete(lh.addrMap, addr)
|
||||||
if lh.l.Level >= logrus.DebugLevel {
|
if debugEnabled {
|
||||||
lh.l.Debugf("deleting %s from lighthouse.", addr)
|
lh.l.Debug("deleting from lighthouse", "vpnAddr", addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -659,9 +673,12 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
|||||||
|
|
||||||
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||||
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddrs", vpnAddrs,
|
||||||
|
"udpAddr", to,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
return false
|
return false
|
||||||
@@ -678,9 +695,12 @@ func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
|||||||
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
|
||||||
udpAddr := protoV4AddrPortToNetAddrPort(to)
|
udpAddr := protoV4AddrPortToNetAddrPort(to)
|
||||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddr", vpnAddr,
|
||||||
|
"udpAddr", udpAddr,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -698,9 +718,12 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
|
|||||||
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
|
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
|
||||||
udpAddr := protoV6AddrPortToNetAddrPort(to)
|
udpAddr := protoV6AddrPortToNetAddrPort(to)
|
||||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Enabled(context.Background(), logging.LevelTrace) {
|
||||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
|
lh.l.Log(context.Background(), logging.LevelTrace, "remoteAllowList.Allow",
|
||||||
Trace("remoteAllowList.Allow")
|
"vpnAddr", vpnAddr,
|
||||||
|
"udpAddr", udpAddr,
|
||||||
|
"allow", allow,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allow {
|
if !allow {
|
||||||
@@ -775,8 +798,10 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if !addr.Is4() {
|
if !addr.Is4() {
|
||||||
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Can't query lighthouse for v6 address using a v1 protocol",
|
||||||
Error("Can't query lighthouse for v6 address using a v1 protocol")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -787,9 +812,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
v1Query, err = msg.Marshal()
|
v1Query, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
lh.l.Error("Failed to marshal lighthouse v1 query payload",
|
||||||
WithField("lighthouseAddr", lhVpnAddr).
|
"error", err,
|
||||||
Error("Failed to marshal lighthouse v1 query payload")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -804,9 +831,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
|
|
||||||
v2Query, err = msg.Marshal()
|
v2Query, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("queryVpnAddr", addr).
|
lh.l.Error("Failed to marshal lighthouse v2 query payload",
|
||||||
WithField("lighthouseAddr", lhVpnAddr).
|
"error", err,
|
||||||
Error("Failed to marshal lighthouse v2 query payload")
|
"queryVpnAddr", addr,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -815,7 +844,11 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
queried++
|
queried++
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
|
lh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "query",
|
||||||
|
"queryVpnAddr", addr,
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -907,8 +940,9 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if v1Update == nil {
|
if v1Update == nil {
|
||||||
if !lh.myVpnNetworks[0].Addr().Is4() {
|
if !lh.myVpnNetworks[0].Addr().Is4() {
|
||||||
lh.l.WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Warn("cannot update lighthouse using v1 protocol without an IPv4 address",
|
||||||
Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var relays []uint32
|
var relays []uint32
|
||||||
@@ -932,8 +966,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
v1Update, err = msg.Marshal()
|
v1Update, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Error while marshaling for lighthouse v1 update",
|
||||||
Error("Error while marshaling for lighthouse v1 update")
|
"error", err,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -959,8 +995,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
v2Update, err = msg.Marshal()
|
v2Update, err = msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
|
lh.l.Error("Error while marshaling for lighthouse v2 update",
|
||||||
Error("Error while marshaling for lighthouse v2 update")
|
"error", err,
|
||||||
|
"lighthouseAddr", lhVpnAddr,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -969,7 +1007,10 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
updated++
|
updated++
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v)
|
lh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "update",
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -983,7 +1024,7 @@ type LightHouseHandler struct {
|
|||||||
out []byte
|
out []byte
|
||||||
pb []byte
|
pb []byte
|
||||||
meta *NebulaMeta
|
meta *NebulaMeta
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
||||||
@@ -1032,14 +1073,19 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := n.Unmarshal(p)
|
err := n.Unmarshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.Error("Failed to unmarshal lighthouse packet",
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"udpAddr", rAddr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.Error("Invalid lighthouse update",
|
||||||
Error("Invalid lighthouse update")
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"udpAddr", rAddr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1067,25 +1113,29 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
|
lhh.l.Debug("I don't answer queries, but received one", "from", addr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
lhh.l.Debug("Dropping malformed HostQuery",
|
||||||
Debugln("Dropping malformed HostQuery")
|
"from", fromVpnAddrs,
|
||||||
|
"details", n.Details,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
lhh.l.Debug("invalid vpn addr for v1 handleHostQuery",
|
||||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"queryVpnAddr", queryVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1110,7 +1160,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
lhh.l.Error("Failed to marshal lighthouse host query reply",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1138,8 +1191,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newDest
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
lhh.l.Debug("unable to punch to host, no addresses in common",
|
||||||
|
"to", crt.Networks(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1165,7 +1220,10 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
|
lhh.l.Error("Failed to marshal lighthouse host was queried for",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1207,8 +1265,11 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
|||||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
lhh.l.Debug("unsupported protocol version",
|
||||||
|
"op", "coalesceAnswers",
|
||||||
|
"version", v,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1221,8 +1282,11 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
|
|
||||||
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
lhh.l.Error("dropping malformed HostQueryReply",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1247,8 +1311,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
lhh.l.Debug("I am not a lighthouse, do not take host updates", "from", fromVpnAddrs)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1271,8 +1335,11 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
|
|
||||||
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||||
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
lhh.l.Debug("Host sent invalid update",
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
"answer", detailsVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1294,7 +1361,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
switch useVersion {
|
switch useVersion {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
if !fromVpnAddrs[0].Is4() {
|
if !fromVpnAddrs[0].Is4() {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
lhh.l.Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message",
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrB := fromVpnAddrs[0].As4()
|
vpnAddrB := fromVpnAddrs[0].As4()
|
||||||
@@ -1302,13 +1371,16 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
case cert.Version2:
|
case cert.Version2:
|
||||||
// do nothing, we want to send a blank message
|
// do nothing, we want to send a blank message
|
||||||
default:
|
default:
|
||||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
lhh.l.Error("invalid protocol version", "useVersion", useVersion)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ln, err := n.MarshalTo(lhh.pb)
|
ln, err := n.MarshalTo(lhh.pb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
|
lhh.l.Error("Failed to marshal lighthouse host update ack",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddrs", fromVpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1325,8 +1397,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
|
|
||||||
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
lhh.l.Debug("dropping invalid HostPunchNotification",
|
||||||
|
"details", n.Details,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1343,8 +1418,11 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
lhh.l.Debug("Punching",
|
||||||
|
"vpnPeer", vpnPeer,
|
||||||
|
"logVpnAddr", logVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1369,8 +1447,10 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
if lhh.lh.punchy.GetRespond() {
|
if lhh.lh.punchy.GetRespond() {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
lhh.l.Debug("Sending a nebula test packet",
|
||||||
|
"vpnAddr", detailsVpnAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||||
|
|||||||
45
logger.go
45
logger.go
@@ -1,45 +0,0 @@
|
|||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func configLogger(l *logrus.Logger, c *config.C) error {
|
|
||||||
// set up our logging level
|
|
||||||
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
|
||||||
}
|
|
||||||
l.SetLevel(logLevel)
|
|
||||||
|
|
||||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
|
||||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
|
||||||
fullTimestamp := (timestampFormat != "")
|
|
||||||
if timestampFormat == "" {
|
|
||||||
timestampFormat = time.RFC3339
|
|
||||||
}
|
|
||||||
|
|
||||||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
|
||||||
switch logFormat {
|
|
||||||
case "text":
|
|
||||||
l.Formatter = &logrus.TextFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
FullTimestamp: fullTimestamp,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
case "json":
|
|
||||||
l.Formatter = &logrus.JSONFormatter{
|
|
||||||
TimestampFormat: timestampFormat,
|
|
||||||
DisableTimestamp: disableTimestamp,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
233
logging/logger.go
Normal file
233
logging/logger.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
// Package logging wires the nebula runtime-reconfigurable slog handler used
|
||||||
|
// by nebula.Main and the nebula CLI binaries. Callers build a logger with
|
||||||
|
// NewLogger, then call ApplyConfig at startup and from a config reload
|
||||||
|
// callback to push logging.level, logging.format, and
|
||||||
|
// logging.disable_timestamp changes onto the logger without rebuilding it.
|
||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config is the subset of *config.C that ApplyConfig reads. Declaring it
|
||||||
|
// here keeps the logging package from depending on config directly, which
|
||||||
|
// would cycle through the shared test helpers (test.NewLogger imports
|
||||||
|
// logging, and config's tests import test). *config.C satisfies this
|
||||||
|
// interface structurally with no adapter.
|
||||||
|
type Config interface {
|
||||||
|
GetString(key, def string) string
|
||||||
|
GetBool(key string, def bool) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LevelTrace is a custom slog level below Debug, used when logging.level is
|
||||||
|
// "trace". slog has no builtin trace level; the value is one step below
|
||||||
|
// slog.LevelDebug in slog's 4-point spacing.
|
||||||
|
const LevelTrace = slog.Level(-8)
|
||||||
|
|
||||||
|
// NewLogger returns a *slog.Logger whose level, format, and timestamp
|
||||||
|
// emission can be reconfigured at runtime via ApplyConfig and the SSH debug
|
||||||
|
// commands. The default configuration is info-level text output so log
|
||||||
|
// calls made before ApplyConfig runs still produce output. Timestamps
|
||||||
|
// follow slog's default RFC3339Nano format; set logging.disable_timestamp
|
||||||
|
// in config to suppress them.
|
||||||
|
//
|
||||||
|
// ApplyConfig and the SSH commands discover the reconfig surface via
|
||||||
|
// structural type-assertion on l.Handler(), so replacement implementations
|
||||||
|
// (tests, platform-specific sinks) need only implement the subset of
|
||||||
|
// {SetLevel(slog.Level), SetFormat(string) error, SetDisableTimestamp(bool)}
|
||||||
|
// they care about. Callers that pass a plain *slog.Logger without these
|
||||||
|
// methods get a silent no-op; reconfiguration is always opt-in.
|
||||||
|
func NewLogger(w io.Writer) *slog.Logger {
|
||||||
|
return slog.New(NewHandler(w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler builds the *Handler that NewLogger wraps. Exported for
|
||||||
|
// platform-specific sinks (notably cmd/nebula-service/logs_windows.go)
|
||||||
|
// that want to wrap the handler with extra behavior, such as tagging each
|
||||||
|
// record with its Event Log severity, while still benefiting from all the
|
||||||
|
// level / format / timestamp / WithAttrs machinery implemented here.
|
||||||
|
func NewHandler(w io.Writer) *Handler {
|
||||||
|
root := &handlerRoot{}
|
||||||
|
root.level.Set(slog.LevelInfo)
|
||||||
|
opts := &slog.HandlerOptions{Level: &root.level}
|
||||||
|
return &Handler{
|
||||||
|
root: root,
|
||||||
|
text: slog.NewTextHandler(w, opts),
|
||||||
|
json: slog.NewJSONHandler(w, opts),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlerRoot carries the reconfiguration state shared by every logger
|
||||||
|
// derived from a NewHandler call. All fields are consulted on the log
|
||||||
|
// path and updated lock-free.
|
||||||
|
type handlerRoot struct {
|
||||||
|
level slog.LevelVar
|
||||||
|
disableTimestamp atomic.Bool
|
||||||
|
// jsonMode picks which of the pre-derived inner handlers Handler.Handle
|
||||||
|
// dispatches to. Flipping it propagates instantly to every derived logger
|
||||||
|
// without rebuilding or chain-replaying anything.
|
||||||
|
jsonMode atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler is the slog.Handler returned by NewHandler. It holds two
|
||||||
|
// pre-derived slog handlers -- one text, one json -- both built from the
|
||||||
|
// same accumulated WithAttrs/WithGroup state. Handle picks which one to
|
||||||
|
// dispatch to based on handlerRoot.jsonMode, so a SetFormat call takes
|
||||||
|
// effect immediately across the whole process without having to rebuild
|
||||||
|
// any derived loggers.
|
||||||
|
type Handler struct {
|
||||||
|
root *handlerRoot
|
||||||
|
text slog.Handler
|
||||||
|
json slog.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Enabled(_ context.Context, l slog.Level) bool {
|
||||||
|
return h.root.level.Level() <= l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Handle(ctx context.Context, r slog.Record) error {
|
||||||
|
if h.root.disableTimestamp.Load() {
|
||||||
|
r.Time = time.Time{}
|
||||||
|
}
|
||||||
|
if h.root.jsonMode.Load() {
|
||||||
|
return h.json.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
return h.text.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
root: h.root,
|
||||||
|
text: h.text.WithAttrs(attrs),
|
||||||
|
json: h.json.WithAttrs(attrs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) WithGroup(name string) slog.Handler {
|
||||||
|
if name == "" {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
root: h.root,
|
||||||
|
text: h.text.WithGroup(name),
|
||||||
|
json: h.json.WithGroup(name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLevel updates the effective log level. Propagates to every derived
|
||||||
|
// logger via the shared LevelVar.
|
||||||
|
func (h *Handler) SetLevel(level slog.Level) { h.root.level.Set(level) }
|
||||||
|
|
||||||
|
// GetLevel reports the current log level.
|
||||||
|
func (h *Handler) GetLevel() slog.Level { return h.root.level.Level() }
|
||||||
|
|
||||||
|
// SetFormat flips the output format atomically. Valid formats are "text"
|
||||||
|
// and "json". Every derived logger sees the new format on its next Handle
|
||||||
|
// call; no rebuild or registration is required.
|
||||||
|
func (h *Handler) SetFormat(format string) error {
|
||||||
|
switch format {
|
||||||
|
case "text":
|
||||||
|
h.root.jsonMode.Store(false)
|
||||||
|
case "json":
|
||||||
|
h.root.jsonMode.Store(true)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFormat reports the currently selected format name.
|
||||||
|
func (h *Handler) GetFormat() string {
|
||||||
|
if h.root.jsonMode.Load() {
|
||||||
|
return "json"
|
||||||
|
}
|
||||||
|
return "text"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableTimestamp toggles whether Handle zeroes r.Time before
|
||||||
|
// dispatching (slog's builtin text/json handlers skip emitting the time
|
||||||
|
// attribute on a zero time).
|
||||||
|
func (h *Handler) SetDisableTimestamp(v bool) { h.root.disableTimestamp.Store(v) }
|
||||||
|
|
||||||
|
// ApplyConfig reads logging.level, logging.format, and (optionally)
|
||||||
|
// logging.disable_timestamp from c and applies them to l. The reconfig
|
||||||
|
// surface is discovered via structural type-assertion on l.Handler(), so
|
||||||
|
// foreign handlers silently opt out of whichever capabilities they do not
|
||||||
|
// implement.
|
||||||
|
//
|
||||||
|
// nebula.Main does NOT call this function on your behalf; callers that want
|
||||||
|
// config-driven log level / format / timestamp updates invoke it at
|
||||||
|
// startup and register it as a reload callback themselves. This keeps the
|
||||||
|
// library from mutating an embedder's logger without their say-so.
|
||||||
|
func ApplyConfig(l *slog.Logger, c Config) error {
|
||||||
|
h := l.Handler()
|
||||||
|
|
||||||
|
lvl, err := ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ls, ok := h.(interface{ SetLevel(slog.Level) }); ok {
|
||||||
|
ls.SetLevel(lvl)
|
||||||
|
}
|
||||||
|
|
||||||
|
format := strings.ToLower(c.GetString("logging.format", "text"))
|
||||||
|
if fs, ok := h.(interface{ SetFormat(string) error }); ok {
|
||||||
|
if err := fs.SetFormat(format); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ts, ok := h.(interface{ SetDisableTimestamp(bool) }); ok {
|
||||||
|
ts.SetDisableTimestamp(c.GetBool("logging.disable_timestamp", false))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseLevel converts a config-string level name ("trace", "debug", "info",
|
||||||
|
// "warn"/"warning", "error", "fatal"/"panic") to a slog.Level. "fatal" and
|
||||||
|
// "panic" are accepted for backwards compatibility with pre-slog configs
|
||||||
|
// and both map to slog.LevelError.
|
||||||
|
func ParseLevel(s string) (slog.Level, error) {
|
||||||
|
switch s {
|
||||||
|
case "trace":
|
||||||
|
return LevelTrace, nil
|
||||||
|
case "debug":
|
||||||
|
return slog.LevelDebug, nil
|
||||||
|
case "info":
|
||||||
|
return slog.LevelInfo, nil
|
||||||
|
case "warn", "warning":
|
||||||
|
return slog.LevelWarn, nil
|
||||||
|
case "error":
|
||||||
|
return slog.LevelError, nil
|
||||||
|
case "fatal", "panic":
|
||||||
|
return slog.LevelError, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("not a valid logging level: %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LevelName returns a human-readable name for a slog.Level matching the
|
||||||
|
// strings accepted by ParseLevel.
|
||||||
|
func LevelName(l slog.Level) string {
|
||||||
|
switch {
|
||||||
|
case l <= LevelTrace:
|
||||||
|
return "trace"
|
||||||
|
case l <= slog.LevelDebug:
|
||||||
|
return "debug"
|
||||||
|
case l <= slog.LevelInfo:
|
||||||
|
return "info"
|
||||||
|
case l <= slog.LevelWarn:
|
||||||
|
return "warn"
|
||||||
|
default:
|
||||||
|
return "error"
|
||||||
|
}
|
||||||
|
}
|
||||||
90
logging/logger_bench_test.go
Normal file
90
logging/logger_bench_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkLogger_* compare the handler returned by NewLogger against a
|
||||||
|
// stock slog text handler. The key thing we care about is the per-log
|
||||||
|
// cost on a logger that has been derived via .With(), because that is the
|
||||||
|
// shape subsystems store on their structs (HostInfo.logger(),
|
||||||
|
// lh.l.With("subsystem", ...), etc.) and call from hot paths.
|
||||||
|
|
||||||
|
func BenchmarkLogger_Stock_RootInfo(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_RootInfo(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Stock_DerivedInfo(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_DerivedInfo(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
l.Info("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gated-off-path benchmarks: mimic the typical hot-path shape
|
||||||
|
// `if l.Enabled(ctx, slog.LevelDebug) { ... }` where the log is gated below
|
||||||
|
// the active level. This is the dominant pattern in inside.go/outside.go and
|
||||||
|
// what we pay on every packet.
|
||||||
|
func BenchmarkLogger_Stock_DerivedEnabledGateMiss(b *testing.B) {
|
||||||
|
l := slog.New(slog.DiscardHandler).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
ctx := context.Background()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if l.Enabled(ctx, slog.LevelDebug) {
|
||||||
|
l.Debug("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Nebula_DerivedEnabledGateMiss(b *testing.B) {
|
||||||
|
l := NewLogger(io.Discard).With(
|
||||||
|
"subsystem", "bench",
|
||||||
|
"localIndex", 1234,
|
||||||
|
)
|
||||||
|
ctx := context.Background()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if l.Enabled(ctx, slog.LevelDebug) {
|
||||||
|
l.Debug("hello", "i", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
39
main.go
39
main.go
@@ -3,13 +3,13 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|
||||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -33,11 +33,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
buildVersion = moduleVersion()
|
buildVersion = moduleVersion()
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logger
|
|
||||||
l.Formatter = &logrus.TextFormatter{
|
|
||||||
FullTimestamp: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print the config if in test, the exit comes later
|
// Print the config if in test, the exit comes later
|
||||||
if configTest {
|
if configTest {
|
||||||
b, err := yaml.Marshal(c.Settings)
|
b, err := yaml.Marshal(c.Settings)
|
||||||
@@ -46,21 +41,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Print the final config
|
// Print the final config
|
||||||
l.Println(string(b))
|
l.Info(string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := configLogger(l, c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
err := configLogger(l, c)
|
|
||||||
if err != nil {
|
|
||||||
l.WithError(err).Error("Failed to configure the logger")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
pki, err := NewPKIFromConfig(l, c)
|
pki, err := NewPKIFromConfig(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||||
@@ -70,9 +53,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||||
}
|
}
|
||||||
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||||
}
|
}
|
||||||
@@ -81,7 +64,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
l.Warn("Failed to configure sshd, ssh debugging will not be available", "error", err)
|
||||||
sshStart = nil
|
sshStart = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -99,7 +82,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
routines = 1
|
routines = 1
|
||||||
}
|
}
|
||||||
if routines > 1 {
|
if routines > 1 {
|
||||||
l.WithField("routines", routines).Info("Using multiple routines")
|
l.Info("Using multiple routines", "routines", routines)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// deprecated and undocumented
|
// deprecated and undocumented
|
||||||
@@ -107,7 +90,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
udpQueues := c.GetInt("listen.routines", 1)
|
udpQueues := c.GetInt("listen.routines", 1)
|
||||||
routines = max(tunQueues, udpQueues)
|
routines = max(tunQueues, udpQueues)
|
||||||
if routines != 1 {
|
if routines != 1 {
|
||||||
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
|
l.Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead", "routines", routines)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,7 +103,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
conntrackCacheTimeout = 1 * time.Second
|
conntrackCacheTimeout = 1 * time.Second
|
||||||
}
|
}
|
||||||
if conntrackCacheTimeout > 0 {
|
if conntrackCacheTimeout > 0 {
|
||||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
l.Info("Using routine-local conntrack cache", "duration", conntrackCacheTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tun overlay.Device
|
var tun overlay.Device
|
||||||
@@ -166,7 +149,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
l.Info("listening", "addr", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
@@ -217,7 +200,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
|
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to start DNS responder")
|
l.Warn("Failed to start DNS responder", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
|
|||||||
149
outside.go
149
outside.go
@@ -1,15 +1,16 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@@ -24,7 +25,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
if len(packet) > 1 {
|
if len(packet) > 1 {
|
||||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
|
f.l.Info("Error while parsing inbound packet",
|
||||||
|
"from", via,
|
||||||
|
"error", err,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -32,8 +37,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if !via.IsRelayed {
|
if !via.IsRelayed {
|
||||||
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
|
f.l.Debug("Refusing to process double encrypted packet", "from", via)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -87,7 +92,10 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
if !ok {
|
if !ok {
|
||||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||||
// its internal mapping. This should never happen.
|
// its internal mapping. This should never happen.
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
hostinfo.logger(f.l).Error("HostInfo missing remote relay index",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
"remoteIndex", h.RemoteIndex,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +116,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
hostinfo.logger(f.l).Info("Failed to find target host info by ip",
|
||||||
|
"relayTo", relay.PeerAddr,
|
||||||
|
"error", err,
|
||||||
|
"hostinfo.vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +136,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
hostinfo.logger(f.l).Info("Unexpected target relay state",
|
||||||
|
"relayTo", relay.PeerAddr,
|
||||||
|
"relayFrom", hostinfo.vpnAddrs[0],
|
||||||
|
"targetRelayState", targetRelay.State,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,9 +154,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
|
||||||
WithField("packet", packet).
|
"error", err,
|
||||||
Error("Failed to decrypt lighthouse packet")
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,9 +175,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).Error("Failed to decrypt test packet",
|
||||||
WithField("packet", packet).
|
"error", err,
|
||||||
Error("Failed to decrypt test packet")
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,14 +212,15 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
}
|
}
|
||||||
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
|
||||||
WithField("packet", packet).
|
"error", err,
|
||||||
Error("Failed to decrypt CloseTunnel packet")
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("from", via).
|
hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via)
|
||||||
Info("Close tunnel received, tearing down.")
|
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
f.closeTunnel(hostinfo)
|
||||||
return
|
return
|
||||||
@@ -211,9 +232,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
|
||||||
WithField("packet", packet).
|
"error", err,
|
||||||
Error("Failed to decrypt Control packet")
|
"from", via,
|
||||||
|
"packet", packet,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,7 +244,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,20 +272,27 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
|||||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
|
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
|
||||||
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
|
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||||
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("lighthouse.remote_allow_list denied roaming", "newAddr", via.UdpAddr)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
hostinfo.logger(f.l).Debug("Suppressing roam back to previous remote",
|
||||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
"suppressSeconds", RoamingSuppressSeconds,
|
||||||
|
"udpAddr", hostinfo.remote,
|
||||||
|
"newAddr", via.UdpAddr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
hostinfo.logger(f.l).Info("Host roamed to new udp ip/port.",
|
||||||
Info("Host roamed to new udp ip/port.")
|
"udpAddr", hostinfo.remote,
|
||||||
|
"newAddr", via.UdpAddr,
|
||||||
|
)
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
hostinfo.lastRoamRemote = hostinfo.remote
|
hostinfo.lastRoamRemote = hostinfo.remote
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
hostinfo.SetRemote(via.UdpAddr)
|
||||||
@@ -491,8 +523,9 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
||||||
hostinfo.logger(f.l).WithField("header", h).
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
Debugln("dropping out of window packet")
|
hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h)
|
||||||
|
}
|
||||||
return nil, errors.New("out of window packet")
|
return nil, errors.New("out of window packet")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -504,20 +537,23 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(out, true, fwPacket)
|
err = newPacket(out, true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
||||||
Warnf("Error while validating inbound packet")
|
"error", err,
|
||||||
|
"packet", out,
|
||||||
|
)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
Debugln("dropping out of window packet")
|
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket)
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -526,10 +562,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).Debug("dropping inbound packet",
|
||||||
WithField("reason", dropReason).
|
"fwPacket", fwPacket,
|
||||||
Debugln("dropping inbound packet")
|
"reason", dropReason,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -537,7 +574,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
_, err = f.readers[q].Write(out)
|
_, err = f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.Error("Failed to write to tun", "error", err)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -553,35 +590,41 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
|||||||
|
|
||||||
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||||
_ = f.outside.WriteTo(b, endpoint)
|
_ = f.outside.WriteTo(b, endpoint)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("index", index).
|
f.l.Debug("Recv error sent",
|
||||||
WithField("udpAddr", endpoint).
|
"index", index,
|
||||||
Debug("Recv error sent")
|
"udpAddr", endpoint,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||||
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
|
if !f.acceptRecvErrorConfig.ShouldRecvError(addr) {
|
||||||
f.l.WithField("index", h.RemoteIndex).
|
f.l.Debug("Recv error received, ignoring",
|
||||||
WithField("udpAddr", addr).
|
"index", h.RemoteIndex,
|
||||||
Debug("Recv error received, ignoring")
|
"udpAddr", addr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.WithField("index", h.RemoteIndex).
|
f.l.Debug("Recv error received",
|
||||||
WithField("udpAddr", addr).
|
"index", h.RemoteIndex,
|
||||||
Debug("Recv error received")
|
"udpAddr", addr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
|
f.l.Debug("Did not find remote index in main hostmap", "remoteIndex", h.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
f.l.Info("Someone spoofing recv_errors?",
|
||||||
|
"addr", addr,
|
||||||
|
"hostinfoRemote", hostinfo.remote,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
package test
|
// Package overlaytest provides fakes of overlay.Device for tests that do
|
||||||
|
// not want to touch a real tun device or route table.
|
||||||
|
package overlaytest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@@ -8,6 +10,9 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NoopTun is an overlay.Device that silently discards every read and write.
|
||||||
|
// Useful in tests that need to construct a nebula Interface but do not
|
||||||
|
// exercise the datapath.
|
||||||
type NoopTun struct{}
|
type NoopTun struct{}
|
||||||
|
|
||||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||||
@@ -2,6 +2,7 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -9,7 +10,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -48,11 +48,14 @@ func (r Route) String() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
func makeRouteTree(l *slog.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
||||||
routeTree := new(bart.Table[routing.Gateways])
|
routeTree := new(bart.Table[routing.Gateways])
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !allowMTU && r.MTU > 0 {
|
if !allowMTU && r.MTU > 0 {
|
||||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
l.Warn("route MTU is not supported on this platform",
|
||||||
|
"goos", runtime.GOOS,
|
||||||
|
"route", r,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
gateways := r.Via
|
gateways := r.Via
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("1.0.0.2")
|
ip, err := netip.ParseAddr("1.0.0.2")
|
||||||
@@ -367,7 +367,7 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
|
|||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 3)
|
assert.Len(t, routes, 3)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(test.NewLogger(), routes, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("192.168.86.1")
|
ip, err := netip.ParseAddr("192.168.86.1")
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
@@ -22,9 +22,9 @@ func (e *NameError) Error() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
type DeviceFactory func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|
||||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func NewDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
switch {
|
switch {
|
||||||
case c.GetBool("tun.disabled", false):
|
case c.GetBool("tun.disabled", false):
|
||||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||||
@@ -36,7 +36,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
return func(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -23,10 +23,10 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTun not supported in Android")
|
return nil, fmt.Errorf("newTun not supported in Android")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -14,7 +15,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -30,7 +30,7 @@ type tun struct {
|
|||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
linkAddr *netroute.LinkAddr
|
linkAddr *netroute.LinkAddr
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
out []byte
|
out []byte
|
||||||
@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
|
|||||||
Lifetime addrLifetime
|
Lifetime addrLifetime
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
name := c.GetString("tun.dev", "")
|
name := c.GetString("tun.dev", "")
|
||||||
ifIndex := -1
|
ifIndex := -1
|
||||||
if name != "" && name != "utun" {
|
if name != "" && name != "utun" {
|
||||||
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -389,8 +389,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
err := addRoute(r.Cidr, t.linkAddr)
|
err := addRoute(r.Cidr, t.linkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, unix.EEXIST) {
|
if errors.Is(err, unix.EEXIST) {
|
||||||
t.l.WithField("route", r.Cidr).
|
t.l.Warn("unable to add unsafe_route, identical route already exists", "route", r.Cidr)
|
||||||
Warnf("unable to add unsafe_route, identical route already exists")
|
|
||||||
} else {
|
} else {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
@@ -400,7 +399,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -415,9 +414,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -19,10 +20,10 @@ type disabledTun struct {
|
|||||||
// Track these metrics since we don't have the tun device to do it for us
|
// Track these metrics since we don't have the tun device to do it for us
|
||||||
tx metrics.Counter
|
tx metrics.Counter
|
||||||
rx metrics.Counter
|
rx metrics.Counter
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
|
||||||
tun := &disabledTun{
|
tun := &disabledTun{
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
read: make(chan []byte, queueLen),
|
read: make(chan []byte, queueLen),
|
||||||
@@ -67,8 +68,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.tx.Inc(1)
|
t.tx.Inc(1)
|
||||||
if t.l.Level >= logrus.DebugLevel {
|
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
t.l.Debug("Write payload", "raw", prettyPacket(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
return copy(b, r), nil
|
return copy(b, r), nil
|
||||||
@@ -85,7 +86,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
|||||||
select {
|
select {
|
||||||
case t.read <- out:
|
case t.read <- out:
|
||||||
default:
|
default:
|
||||||
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
t.l.Debug("tun_disabled: dropped ICMP Echo Reply response")
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -96,11 +97,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
|
|
||||||
// Check for ICMP Echo Request before spending time doing the full parsing
|
// Check for ICMP Echo Request before spending time doing the full parsing
|
||||||
if t.handleICMPEchoRequest(b) {
|
if t.handleICMPEchoRequest(b) {
|
||||||
if t.l.Level >= logrus.DebugLevel {
|
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
t.l.Debug("Disabled tun responded to ICMP Echo Request", "raw", prettyPacket(b))
|
||||||
}
|
}
|
||||||
} else if t.l.Level >= logrus.DebugLevel {
|
} else if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
t.l.Debug("Disabled tun received unexpected payload", "raw", prettyPacket(b))
|
||||||
}
|
}
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -17,8 +18,9 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
@@ -93,7 +95,7 @@ type tun struct {
|
|||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
linkAddr *netroute.LinkAddr
|
linkAddr *netroute.LinkAddr
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
fd int
|
fd int
|
||||||
shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls
|
shutdownR int // read end of the shutdown pipe; closing the write end wakes blocked polls
|
||||||
@@ -243,7 +245,7 @@ func (t *tun) Close() error {
|
|||||||
|
|
||||||
if t.fd >= 0 {
|
if t.fd >= 0 {
|
||||||
if err := unix.Close(t.fd); err != nil {
|
if err := unix.Close(t.fd); err != nil {
|
||||||
t.l.WithError(err).Error("Error closing device")
|
t.l.Error("Error closing device", "error", err)
|
||||||
}
|
}
|
||||||
t.fd = -1
|
t.fd = -1
|
||||||
}
|
}
|
||||||
@@ -264,7 +266,7 @@ func (t *tun) Close() error {
|
|||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).Error("Error destroying tunnel")
|
t.l.Error("Error destroying tunnel", "error", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -277,11 +279,11 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open existing tun device
|
// Try to open existing tun device
|
||||||
var fd int
|
var fd int
|
||||||
var err error
|
var err error
|
||||||
@@ -584,7 +586,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -599,9 +601,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -14,7 +15,6 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -25,14 +25,14 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||||
t := &tun{
|
t := &tun{
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -17,7 +18,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -213,7 +213,7 @@ type tun struct {
|
|||||||
routesFromSystem map[netip.Prefix]routing.Gateways
|
routesFromSystem map[netip.Prefix]routing.Gateways
|
||||||
routesFromSystemLock sync.Mutex
|
routesFromSystemLock sync.Mutex
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
@@ -238,7 +238,7 @@ type ifreqQLEN struct {
|
|||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -249,7 +249,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||||
@@ -299,7 +299,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
tfd, err := newTunFd(fd)
|
tfd, err := newTunFd(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
@@ -378,16 +378,16 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
if !initial {
|
if !initial {
|
||||||
if oldMaxMTU != newMaxMTU {
|
if oldMaxMTU != newMaxMTU {
|
||||||
t.setMTU()
|
t.setMTU()
|
||||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
t.l.Info("Set max MTU", "mtu", t.MaxMTU, "oldMTU", oldMaxMTU)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldDefaultMTU != newDefaultMTU {
|
if oldDefaultMTU != newDefaultMTU {
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
err := t.setDefaultRoute(t.vpnNetworks[i])
|
err := t.setDefaultRoute(t.vpnNetworks[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.Warn(err)
|
t.l.Warn(err.Error())
|
||||||
} else {
|
} else {
|
||||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
t.l.Info("Set default MTU", "mtu", t.DefaultMTU, "oldMTU", oldDefaultMTU)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -492,9 +492,9 @@ func (t *tun) addIPs(link netlink.Link) error {
|
|||||||
}
|
}
|
||||||
err = netlink.AddrDel(link, &al[i])
|
err = netlink.AddrDel(link, &al[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
t.l.Error("failed to remove address from tun address list", "error", err)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
t.l.Info("removed address not listed in cert(s)", "removed", al[i].String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,12 +538,12 @@ func (t *tun) Activate() error {
|
|||||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
t.l.Error("Failed to set tun tx queue length", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const modeNone = 1
|
const modeNone = 1
|
||||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
t.l.Warn("Failed to disable link local address generation", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = t.addIPs(link); err != nil {
|
if err = t.addIPs(link); err != nil {
|
||||||
@@ -582,7 +582,7 @@ func (t *tun) setMTU() {
|
|||||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
||||||
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
t.l.Error("Failed to set tun mtu", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -605,7 +605,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
err := netlink.RouteReplace(&nr)
|
err := netlink.RouteReplace(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
|
||||||
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -613,7 +613,11 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
|
t.l.Warn("Failed to set default route MTU, retrying",
|
||||||
|
"error", err,
|
||||||
|
"cidr", cidr,
|
||||||
|
"mtu", t.DefaultMTU,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -658,7 +662,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -690,9 +694,9 @@ func (t *tun) removeRoutes(routes []Route) {
|
|||||||
|
|
||||||
err := netlink.RouteDel(&nr)
|
err := netlink.RouteDel(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -721,11 +725,11 @@ func (t *tun) watchRoutes() {
|
|||||||
netlinkOptions := netlink.RouteSubscribeOptions{
|
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||||
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
||||||
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
||||||
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
ErrorCallback: func(e error) { t.l.Error("netlink error", "error", e) },
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
|
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
|
||||||
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
|
t.l.Error("failed to subscribe to system route changes", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -767,7 +771,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
link, err := netlink.LinkByName(t.Device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
|
t.l.Error("Ignoring route update: failed to get link by name", "deviceName", t.Device)
|
||||||
return gateways
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -779,10 +783,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||||
} else {
|
} else {
|
||||||
// Gateway isn't in our overlay network, ignore
|
// Gateway isn't in our overlay network, ignore
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -795,10 +799,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||||
} else {
|
} else {
|
||||||
// Gateway isn't in our overlay network, ignore
|
// Gateway isn't in our overlay network, ignore
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
t.l.Debug("Ignoring route update, gateway is not in our network", "route", r)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
t.l.Debug("Ignoring route update, invalid gateway or via address", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -830,18 +834,18 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||||
if len(gateways) == 0 {
|
if len(gateways) == 0 {
|
||||||
// No gateways relevant to our network, no routing changes required.
|
// No gateways relevant to our network, no routing changes required.
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
t.l.Debug("Ignoring route update, no gateways", "route", r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Dst == nil {
|
if r.Dst == nil {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
t.l.Debug("Ignoring route update, no destination address", "route", r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
|
t.l.Debug("Ignoring route update, invalid destination address", "route", r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -852,12 +856,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
|
|
||||||
t.routesFromSystemLock.Lock()
|
t.routesFromSystemLock.Lock()
|
||||||
if r.Type == unix.RTM_NEWROUTE {
|
if r.Type == unix.RTM_NEWROUTE {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
t.l.Info("Adding route", "destination", dst, "via", gateways)
|
||||||
t.routesFromSystem[dst] = gateways
|
t.routesFromSystem[dst] = gateways
|
||||||
newTree.Insert(dst, gateways)
|
newTree.Insert(dst, gateways)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
t.l.Info("Removing route", "destination", dst, "via", gateways)
|
||||||
delete(t.routesFromSystem, dst)
|
delete(t.routesFromSystem, dst)
|
||||||
newTree.Delete(dst)
|
newTree.Delete(dst)
|
||||||
}
|
}
|
||||||
@@ -888,18 +892,18 @@ func (t *tun) Close() error {
|
|||||||
}
|
}
|
||||||
err := t.readers[i].Close()
|
err := t.readers[i].Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithField("reader", i).WithError(err).Error("error closing tun reader")
|
t.l.Error("error closing tun reader", "reader", i, "error", err)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("reader", i).Info("closed tun reader")
|
t.l.Info("closed tun reader", "reader", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//this is t.readers[0] too
|
//this is t.readers[0] too
|
||||||
err := t.tunFile.Close()
|
err := t.tunFile.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader")
|
t.l.Error("error closing tun reader", "reader", 0, "error", err)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("reader", 0).Info("closed tun reader")
|
t.l.Info("closed tun reader", "reader", 0)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -15,7 +16,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -63,18 +63,18 @@ type tun struct {
|
|||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
f *os.File
|
f *os.File
|
||||||
fd int
|
fd int
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
@@ -92,7 +92,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
err = unix.SetNonblock(fd, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
l.Warn("Failed to set the tun device as nonblocking", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
@@ -416,7 +416,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -431,9 +431,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -15,7 +16,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -54,7 +54,7 @@ type tun struct {
|
|||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
f *os.File
|
f *os.File
|
||||||
fd int
|
fd int
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
@@ -63,11 +63,11 @@ type tun struct {
|
|||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
@@ -85,7 +85,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
err = unix.SetNonblock(fd, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
l.Warn("Failed to set the tun device as nonblocking", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
@@ -336,7 +336,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,9 +351,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
|
|
||||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -4,14 +4,15 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -21,14 +22,14 @@ type TestTun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *bart.Table[routing.Gateways]
|
routeTree *bart.Table[routing.Gateways]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
rxPackets chan []byte // Packets to receive into nebula
|
rxPackets chan []byte // Packets to receive into nebula
|
||||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||||
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -49,7 +50,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,8 +62,8 @@ func (t *TestTun) Send(packet []byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.l.Level >= logrus.DebugLevel {
|
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
|
t.l.Debug("Tun receiving injected packet", "dataLen", len(packet))
|
||||||
}
|
}
|
||||||
t.rxPackets <- packet
|
t.rxPackets <- packet
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"crypto"
|
"crypto"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -16,7 +17,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -33,16 +33,16 @@ type winTun struct {
|
|||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
tun *wintun.NativeTun
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||||
err := checkWinTunExists()
|
err := checkWinTunExists()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||||
@@ -71,7 +71,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||||
// Trying a second time resolves the issue.
|
// Trying a second time resolves the issue.
|
||||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
l.Debug("Failed to create wintun device, retrying", "error", err)
|
||||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &NameError{
|
return nil, &NameError{
|
||||||
@@ -170,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
|||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.Info("Added route", "route", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !foundDefault4 {
|
if !foundDefault4 {
|
||||||
@@ -208,9 +208,9 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||||||
// See comment on luid.AddRoute
|
// See comment on luid.AddRoute
|
||||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.Error("Failed to remove route", "error", err, "route", r)
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.Info("Removed route", "route", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ package overlay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func NewUserDeviceFromConfig(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
return NewUserDevice(vpnNetworks)
|
return NewUserDevice(vpnNetworks)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
18
pki.go
18
pki.go
@@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -15,7 +16,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
@@ -24,7 +24,7 @@ import (
|
|||||||
type PKI struct {
|
type PKI struct {
|
||||||
cs atomic.Pointer[CertState]
|
cs atomic.Pointer[CertState]
|
||||||
caPool atomic.Pointer[cert.CAPool]
|
caPool atomic.Pointer[cert.CAPool]
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type CertState struct {
|
type CertState struct {
|
||||||
@@ -46,7 +46,7 @@ type CertState struct {
|
|||||||
myVpnBroadcastAddrsTable *bart.Lite
|
myVpnBroadcastAddrsTable *bart.Lite
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
func NewPKIFromConfig(l *slog.Logger, c *config.C) (*PKI, error) {
|
||||||
pki := &PKI{l: l}
|
pki := &PKI{l: l}
|
||||||
err := pki.reload(c, true)
|
err := pki.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -182,9 +182,9 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
p.cs.Store(newState)
|
p.cs.Store(newState)
|
||||||
|
|
||||||
if initial {
|
if initial {
|
||||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
p.l.Debug("Client nebula certificate(s)", "cert", newState)
|
||||||
} else {
|
} else {
|
||||||
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
|
p.l.Info("Client certificate(s) refreshed from disk", "cert", newState)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -196,7 +196,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.caPool.Store(caPool)
|
p.caPool.Store(caPool)
|
||||||
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
p.l.Debug("Trusted CA fingerprints", "fingerprints", caPool.GetFingerprints())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -487,7 +487,7 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
|
|||||||
return c, b, nil
|
return c, b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
func loadCAPoolFromConfig(l *slog.Logger, c *config.C) (*cert.CAPool, error) {
|
||||||
caPathOrPEM := c.GetString("pki.ca", "")
|
caPathOrPEM := c.GetString("pki.ca", "")
|
||||||
if caPathOrPEM == "" {
|
if caPathOrPEM == "" {
|
||||||
return nil, errors.New("no pki.ca path or PEM data provided")
|
return nil, errors.New("no pki.ca path or PEM data provided")
|
||||||
@@ -512,7 +512,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
|||||||
for _, crt := range caPool.CAs {
|
for _, crt := range caPool.CAs {
|
||||||
if crt.Certificate.Expired(time.Now()) {
|
if crt.Certificate.Expired(time.Now()) {
|
||||||
expired++
|
expired++
|
||||||
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
|
l.Warn("expired certificate present in CA pool", "cert", crt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,7 +530,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
|||||||
caPool.BlocklistFingerprint(fp)
|
caPool.BlocklistFingerprint(fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
l.Info("Blocklisted certificates", "fingerprintCount", len(bl))
|
||||||
}
|
}
|
||||||
|
|
||||||
return caPool, nil
|
return caPool, nil
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func BenchmarkReloadConfigWithCAs(b *testing.B) {
|
|||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
require.NoError(b, c.Load(dir))
|
require.NoError(b, c.Load(dir))
|
||||||
|
|
||||||
_, err := NewPKIFromConfig(l, c)
|
_, err := NewPKIFromConfig(test.NewLogger(), c)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|||||||
14
punchy.go
14
punchy.go
@@ -1,10 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,10 +14,10 @@ type Punchy struct {
|
|||||||
delay atomic.Int64
|
delay atomic.Int64
|
||||||
respondDelay atomic.Int64
|
respondDelay atomic.Int64
|
||||||
punchEverything atomic.Bool
|
punchEverything atomic.Bool
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
|
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
|
||||||
p := &Punchy{l: l}
|
p := &Punchy{l: l}
|
||||||
|
|
||||||
p.reload(c, true)
|
p.reload(c, true)
|
||||||
@@ -62,7 +62,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
|||||||
p.respond.Store(yes)
|
p.respond.Store(yes)
|
||||||
|
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.Infof("punchy.respond changed to %v", p.GetRespond())
|
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,21 +70,21 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
|||||||
if initial || c.HasChanged("punchy.delay") {
|
if initial || c.HasChanged("punchy.delay") {
|
||||||
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
|
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.Infof("punchy.delay changed to %s", p.GetDelay())
|
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("punchy.target_all_remotes") {
|
if initial || c.HasChanged("punchy.target_all_remotes") {
|
||||||
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
|
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
|
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("punchy.respond_delay") {
|
if initial || c.HasChanged("punchy.respond_delay") {
|
||||||
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
|
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
|
||||||
if !initial {
|
if !initial {
|
||||||
p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay())
|
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
173
punchy_test.go
173
punchy_test.go
@@ -1,6 +1,8 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,7 +17,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
|||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.False(t, p.GetPunch())
|
assert.False(t, p.GetPunch())
|
||||||
assert.False(t, p.GetRespond())
|
assert.False(t, p.GetRespond())
|
||||||
assert.Equal(t, time.Second, p.GetDelay())
|
assert.Equal(t, time.Second, p.GetDelay())
|
||||||
@@ -23,33 +25,33 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// punchy deprecation
|
// punchy deprecation
|
||||||
c.Settings["punchy"] = true
|
c.Settings["punchy"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetPunch())
|
assert.True(t, p.GetPunch())
|
||||||
|
|
||||||
// punchy.punch
|
// punchy.punch
|
||||||
c.Settings["punchy"] = map[string]any{"punch": true}
|
c.Settings["punchy"] = map[string]any{"punch": true}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetPunch())
|
assert.True(t, p.GetPunch())
|
||||||
|
|
||||||
// punch_back deprecation
|
// punch_back deprecation
|
||||||
c.Settings["punch_back"] = true
|
c.Settings["punch_back"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
|
|
||||||
// punchy.respond
|
// punchy.respond
|
||||||
c.Settings["punchy"] = map[string]any{"respond": true}
|
c.Settings["punchy"] = map[string]any{"respond": true}
|
||||||
c.Settings["punch_back"] = false
|
c.Settings["punch_back"] = false
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
|
|
||||||
// punchy.delay
|
// punchy.delay
|
||||||
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.Equal(t, time.Minute, p.GetDelay())
|
assert.Equal(t, time.Minute, p.GetDelay())
|
||||||
|
|
||||||
// punchy.respond_delay
|
// punchy.respond_delay
|
||||||
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +64,7 @@ punchy:
|
|||||||
delay: 1m
|
delay: 1m
|
||||||
respond: false
|
respond: false
|
||||||
`))
|
`))
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(test.NewLogger(), c)
|
||||||
assert.Equal(t, delay, p.GetDelay())
|
assert.Equal(t, delay, p.GetDelay())
|
||||||
assert.False(t, p.GetRespond())
|
assert.False(t, p.GetRespond())
|
||||||
|
|
||||||
@@ -76,3 +78,158 @@ punchy:
|
|||||||
assert.Equal(t, newDelay, p.GetDelay())
|
assert.Equal(t, newDelay, p.GetDelay())
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The tests below pin the shape of each log line Punchy produces so changes
|
||||||
|
// cannot silently break whatever operators are grepping for. The assertions
|
||||||
|
// are on the structured message + attrs (e.g. "punchy.respond changed" with
|
||||||
|
// a respond=true field) rather than a formatted string.
|
||||||
|
//
|
||||||
|
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
|
||||||
|
// not supported" warning whenever any key under punchy changes, because of
|
||||||
|
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
|
||||||
|
// punchy form. The tests filter by message rather than asserting total
|
||||||
|
// entry counts so that warning is tolerated without being locked into
|
||||||
|
// the format.
|
||||||
|
|
||||||
|
type capturedEntry struct {
|
||||||
|
Level slog.Level
|
||||||
|
Msg string
|
||||||
|
Attrs map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// capturingHandler is a slog.Handler that records each Record it receives so
|
||||||
|
// tests can assert on the level, message, and attribute map of individual log
|
||||||
|
// lines without coupling to any specific text format.
|
||||||
|
type capturingHandler struct {
|
||||||
|
entries []capturedEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *capturingHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
|
||||||
|
|
||||||
|
func (h *capturingHandler) Handle(_ context.Context, r slog.Record) error {
|
||||||
|
e := capturedEntry{
|
||||||
|
Level: r.Level,
|
||||||
|
Msg: r.Message,
|
||||||
|
Attrs: make(map[string]any),
|
||||||
|
}
|
||||||
|
r.Attrs(func(a slog.Attr) bool {
|
||||||
|
e.Attrs[a.Key] = a.Value.Resolve().Any()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
h.entries = append(h.entries, e)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *capturingHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
|
||||||
|
func (h *capturingHandler) WithGroup(_ string) slog.Handler { return h }
|
||||||
|
|
||||||
|
func newCapturingPunchyLogger(t *testing.T) (*slog.Logger, *capturingHandler) {
|
||||||
|
t.Helper()
|
||||||
|
hook := &capturingHandler{}
|
||||||
|
return slog.New(hook), hook
|
||||||
|
}
|
||||||
|
|
||||||
|
func findEntry(t *testing.T, entries []capturedEntry, msg string) capturedEntry {
|
||||||
|
t.Helper()
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.Msg == msg {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("no entry with message %q among %d entries", msg, len(entries))
|
||||||
|
return capturedEntry{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
|
||||||
|
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy enabled")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Empty(t, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||||
|
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy disabled")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Empty(t, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
|
||||||
|
assert.Equal(t, slog.LevelWarn, entry.Level)
|
||||||
|
assert.Empty(t, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.respond changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"respond": true}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.delay changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"delay": 10 * time.Second}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.target_all_remotes changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"target_all_remotes": true}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
|
||||||
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
|
c := config.NewC(test.NewLogger())
|
||||||
|
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
|
||||||
|
NewPunchyFromConfig(l, c)
|
||||||
|
hook.entries = nil
|
||||||
|
|
||||||
|
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
|
||||||
|
|
||||||
|
entry := findEntry(t, hook.entries, "punchy.respond_delay changed")
|
||||||
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
|
assert.Equal(t, map[string]any{"respond_delay": 15 * time.Second}, entry.Attrs)
|
||||||
|
}
|
||||||
|
|||||||
165
relay_manager.go
165
relay_manager.go
@@ -5,22 +5,22 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
type relayManager struct {
|
type relayManager struct {
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
hostmap *HostMap
|
hostmap *HostMap
|
||||||
amRelay atomic.Bool
|
amRelay atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager {
|
func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager {
|
||||||
rm := &relayManager{
|
rm := &relayManager{
|
||||||
l: l,
|
l: l,
|
||||||
hostmap: hostmap,
|
hostmap: hostmap,
|
||||||
@@ -29,7 +29,7 @@ func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c
|
|||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := rm.reload(c, false)
|
err := rm.reload(c, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to reload relay_manager")
|
rm.l.Error("Failed to reload relay_manager", "error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return rm
|
return rm
|
||||||
@@ -52,7 +52,7 @@ func (rm *relayManager) setAmRelay(v bool) {
|
|||||||
|
|
||||||
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
|
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
|
||||||
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
|
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
|
||||||
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
func AddRelay(l *slog.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
defer hm.Unlock()
|
defer hm.Unlock()
|
||||||
for range 32 {
|
for range 32 {
|
||||||
@@ -92,24 +92,24 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
|
|||||||
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
|
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
|
||||||
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
|
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
|
||||||
if !ok {
|
if !ok {
|
||||||
fields := logrus.Fields{
|
var relayFrom, relayTo any
|
||||||
"relay": relayHostInfo.vpnAddrs[0],
|
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.RelayFromAddr == nil {
|
if m.RelayFromAddr == nil {
|
||||||
fields["relayFrom"] = m.OldRelayFromAddr
|
relayFrom = m.OldRelayFromAddr
|
||||||
} else {
|
} else {
|
||||||
fields["relayFrom"] = m.RelayFromAddr
|
relayFrom = m.RelayFromAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.RelayToAddr == nil {
|
if m.RelayToAddr == nil {
|
||||||
fields["relayTo"] = m.OldRelayToAddr
|
relayTo = m.OldRelayToAddr
|
||||||
} else {
|
} else {
|
||||||
fields["relayTo"] = m.RelayToAddr
|
relayTo = m.RelayToAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
rm.l.WithFields(fields).Info("relayManager failed to update relay")
|
rm.l.Info("relayManager failed to update relay",
|
||||||
|
"relay", relayHostInfo.vpnAddrs[0],
|
||||||
|
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||||
|
"relayFrom", relayFrom,
|
||||||
|
"relayTo", relayTo,
|
||||||
|
)
|
||||||
return nil, fmt.Errorf("unknown relay")
|
return nil, fmt.Errorf("unknown relay")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
|
|||||||
msg := &NebulaControl{}
|
msg := &NebulaControl{}
|
||||||
err := msg.Unmarshal(d)
|
err := msg.Unmarshal(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
|
h.logger(f.l).Error("Failed to unmarshal control message", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,20 +147,20 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
|
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("handleCreateRelayResponse",
|
||||||
"relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
|
"relayFrom", protoAddrToNetAddr(m.RelayFromAddr),
|
||||||
"relayTo": protoAddrToNetAddr(m.RelayToAddr),
|
"relayTo", protoAddrToNetAddr(m.RelayToAddr),
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": m.ResponderRelayIndex,
|
"responderRelayIndex", m.ResponderRelayIndex,
|
||||||
"vpnAddrs": h.vpnAddrs}).
|
"vpnAddrs", h.vpnAddrs,
|
||||||
Info("handleCreateRelayResponse")
|
)
|
||||||
|
|
||||||
target := m.RelayToAddr
|
target := m.RelayToAddr
|
||||||
targetAddr := protoAddrToNetAddr(target)
|
targetAddr := protoAddrToNetAddr(target)
|
||||||
|
|
||||||
relay, err := rm.EstablishRelay(h, m)
|
relay, err := rm.EstablishRelay(h, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rm.l.WithError(err).Error("Failed to update relay for relayTo")
|
rm.l.Error("Failed to update relay for relayTo", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Do I need to complete the relays now?
|
// Do I need to complete the relays now?
|
||||||
@@ -170,12 +170,12 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
|
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
|
||||||
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
|
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
|
||||||
if peerHostInfo == nil {
|
if peerHostInfo == nil {
|
||||||
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
|
rm.l.Error("Can't find a HostInfo for peer", "relayTo", relay.PeerAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
|
rm.l.Error("peerRelay does not have Relay state for relayTo", "relayTo", peerHostInfo.vpnAddrs[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch peerRelay.State {
|
switch peerRelay.State {
|
||||||
@@ -193,12 +193,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
peer := peerHostInfo.vpnAddrs[0]
|
peer := peerHostInfo.vpnAddrs[0]
|
||||||
if !peer.Is4() {
|
if !peer.Is4() {
|
||||||
rm.l.WithField("relayFrom", peer).
|
rm.l.Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address",
|
||||||
WithField("relayTo", target).
|
"relayFrom", peer,
|
||||||
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
|
"relayTo", target,
|
||||||
WithField("responderRelayIndex", resp.ResponderRelayIndex).
|
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||||
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
|
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||||
Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address")
|
"vpnAddrs", peerHostInfo.vpnAddrs,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,17 +214,16 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rm.l.WithError(err).
|
rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
|
||||||
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("send CreateRelayResponse",
|
||||||
"relayFrom": resp.RelayFromAddr,
|
"relayFrom", resp.RelayFromAddr,
|
||||||
"relayTo": resp.RelayToAddr,
|
"relayTo", resp.RelayToAddr,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||||
"vpnAddrs": peerHostInfo.vpnAddrs}).
|
"vpnAddrs", peerHostInfo.vpnAddrs,
|
||||||
Info("send CreateRelayResponse")
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -232,17 +232,18 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
from := protoAddrToNetAddr(m.RelayFromAddr)
|
from := protoAddrToNetAddr(m.RelayFromAddr)
|
||||||
target := protoAddrToNetAddr(m.RelayToAddr)
|
target := protoAddrToNetAddr(m.RelayToAddr)
|
||||||
|
|
||||||
logMsg := rm.l.WithFields(logrus.Fields{
|
logMsg := rm.l.With(
|
||||||
"relayFrom": from,
|
"relayFrom", from,
|
||||||
"relayTo": target,
|
"relayTo", target,
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
"initiatorRelayIndex", m.InitiatorRelayIndex,
|
||||||
"vpnAddrs": h.vpnAddrs})
|
"vpnAddrs", h.vpnAddrs,
|
||||||
|
)
|
||||||
|
|
||||||
logMsg.Info("handleCreateRelayRequest")
|
logMsg.Info("handleCreateRelayRequest")
|
||||||
// Is the source of the relay me? This should never happen, but did happen due to
|
// Is the source of the relay me? This should never happen, but did happen due to
|
||||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||||
if f.myVpnAddrsTable.Contains(from) {
|
if f.myVpnAddrsTable.Contains(from) {
|
||||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
logMsg.Error("Discarding relay request from myself", "myIP", from)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,37 +262,37 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
||||||
// We got a brand new Relay request, because its index is different than what we saw before.
|
// We got a brand new Relay request, because its index is different than what we saw before.
|
||||||
// This should never happen. The peer should never change an index, once created.
|
// This should never happen. The peer should never change an index, once created.
|
||||||
logMsg.WithFields(logrus.Fields{
|
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
|
||||||
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
"existingRemoteIndex", existingRelay.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case Disestablished:
|
case Disestablished:
|
||||||
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
|
||||||
// We got a brand new Relay request, because its index is different than what we saw before.
|
// We got a brand new Relay request, because its index is different than what we saw before.
|
||||||
// This should never happen. The peer should never change an index, once created.
|
// This should never happen. The peer should never change an index, once created.
|
||||||
logMsg.WithFields(logrus.Fields{
|
logMsg.Error("Existing relay mismatch with CreateRelayRequest",
|
||||||
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
"existingRemoteIndex", existingRelay.RemoteIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Mark the relay as 'Established' because it's safe to use again
|
// Mark the relay as 'Established' because it's safe to use again
|
||||||
h.relayState.UpdateRelayForByIpState(from, Established)
|
h.relayState.UpdateRelayForByIpState(from, Established)
|
||||||
case PeerRequested:
|
case PeerRequested:
|
||||||
// I should never be in this state, because I am terminal, not forwarding.
|
// I should never be in this state, because I am terminal, not forwarding.
|
||||||
logMsg.WithFields(logrus.Fields{
|
logMsg.Error("Unexpected Relay State found",
|
||||||
"existingRemoteIndex": existingRelay.RemoteIndex,
|
"existingRemoteIndex", existingRelay.RemoteIndex,
|
||||||
"state": existingRelay.State}).Error("Unexpected Relay State found")
|
"state", existingRelay.State)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
|
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.WithError(err).Error("Failed to add relay")
|
logMsg.Error("Failed to add relay", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
relay, ok := h.relayState.QueryRelayForByIp(from)
|
relay, ok := h.relayState.QueryRelayForByIp(from)
|
||||||
if !ok {
|
if !ok {
|
||||||
logMsg.WithField("from", from).Error("Relay State not found")
|
logMsg.Error("Relay State not found", "from", from)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -313,17 +314,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
|
||||||
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("send CreateRelayResponse",
|
||||||
"relayFrom": from,
|
"relayFrom", from,
|
||||||
"relayTo": target,
|
"relayTo", target,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex", resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex", resp.ResponderRelayIndex,
|
||||||
"vpnAddrs": h.vpnAddrs}).
|
"vpnAddrs", h.vpnAddrs,
|
||||||
Info("send CreateRelayResponse")
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
@@ -363,12 +363,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if !h.vpnAddrs[0].Is4() {
|
if !h.vpnAddrs[0].Is4() {
|
||||||
rm.l.WithField("relayFrom", h.vpnAddrs[0]).
|
rm.l.Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address",
|
||||||
WithField("relayTo", target).
|
"relayFrom", h.vpnAddrs[0],
|
||||||
WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
|
"relayTo", target,
|
||||||
WithField("responderRelayIndex", req.ResponderRelayIndex).
|
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||||
WithField("vpnAddr", target).
|
"responderRelayIndex", req.ResponderRelayIndex,
|
||||||
Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address")
|
"vpnAddr", target,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,17 +384,16 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err)
|
||||||
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
|
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.Info("send CreateRelayRequest",
|
||||||
"relayFrom": h.vpnAddrs[0],
|
"relayFrom", h.vpnAddrs[0],
|
||||||
"relayTo": target,
|
"relayTo", target,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex", req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex", req.ResponderRelayIndex,
|
||||||
"vpnAddr": target}).
|
"vpnAddr", target,
|
||||||
Info("send CreateRelayRequest")
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also track the half-created Relay state just received
|
// Also track the half-created Relay state just received
|
||||||
@@ -401,8 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
if !ok {
|
if !ok {
|
||||||
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.Error("relayManager Failed to allocate a local index for relay", "error", err)
|
||||||
WithError(err).Error("relayManager Failed to allocate a local index for relay")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -10,8 +11,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// forEachFunc is used to benefit folks that want to do work inside the lock
|
// forEachFunc is used to benefit folks that want to do work inside the lock
|
||||||
@@ -66,11 +65,11 @@ type hostnamesResults struct {
|
|||||||
network string
|
network string
|
||||||
lookupTimeout time.Duration
|
lookupTimeout time.Duration
|
||||||
cancelFn func()
|
cancelFn func()
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
ips atomic.Pointer[map[netip.AddrPort]struct{}]
|
ips atomic.Pointer[map[netip.AddrPort]struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
|
func NewHostnameResults(ctx context.Context, l *slog.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
|
||||||
r := &hostnamesResults{
|
r := &hostnamesResults{
|
||||||
hostnames: make([]hostnamePort, len(hostPorts)),
|
hostnames: make([]hostnamePort, len(hostPorts)),
|
||||||
network: network,
|
network: network,
|
||||||
@@ -121,7 +120,11 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
|
|||||||
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
|
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
|
l.Error("DNS resolution failed for static_map host",
|
||||||
|
"hostname", hostPort.name,
|
||||||
|
"network", r.network,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, a := range addrs {
|
for _, a := range addrs {
|
||||||
@@ -145,7 +148,10 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if different {
|
if different {
|
||||||
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
|
l.Info("DNS results changed for host list",
|
||||||
|
"origSet", origSet,
|
||||||
|
"newSet", netipAddrs,
|
||||||
|
)
|
||||||
r.ips.Store(&netipAddrs)
|
r.ips.Store(&netipAddrs)
|
||||||
onUpdate()
|
onUpdate()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"go.yaml.in/yaml/v3"
|
"go.yaml.in/yaml/v3"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
@@ -75,8 +75,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := logrus.New()
|
logger := logging.NewLogger(os.Stdout)
|
||||||
logger.Out = os.Stdout
|
|
||||||
|
|
||||||
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
85
ssh.go
85
ssh.go
@@ -6,21 +6,21 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"maps"
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/logging"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -57,12 +57,12 @@ type sshDeviceInfoFlags struct {
|
|||||||
Pretty bool
|
Pretty bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
|
func wireSSHReload(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) {
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshRun, err := configSSH(l, ssh, c)
|
sshRun, err := configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to reconfigure the sshd")
|
l.Error("Failed to reconfigure the sshd", "error", err)
|
||||||
ssh.Stop()
|
ssh.Stop()
|
||||||
}
|
}
|
||||||
if sshRun != nil {
|
if sshRun != nil {
|
||||||
@@ -78,7 +78,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
|
|||||||
// updates the passed-in SSHServer. On success, it returns a function
|
// updates the passed-in SSHServer. On success, it returns a function
|
||||||
// that callers may invoke to run the configured ssh server. On
|
// that callers may invoke to run the configured ssh server. On
|
||||||
// failure, it returns nil, error.
|
// failure, it returns nil, error.
|
||||||
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
|
func configSSH(l *slog.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
|
||||||
listen := c.GetString("sshd.listen", "")
|
listen := c.GetString("sshd.listen", "")
|
||||||
if listen == "" {
|
if listen == "" {
|
||||||
return nil, fmt.Errorf("sshd.listen must be provided")
|
return nil, fmt.Errorf("sshd.listen must be provided")
|
||||||
@@ -120,7 +120,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
for _, caAuthorizedKey := range rawCAs {
|
for _, caAuthorizedKey := range rawCAs {
|
||||||
err := ssh.AddTrustedCA(caAuthorizedKey)
|
err := ssh.AddTrustedCA(caAuthorizedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
|
l.Warn("SSH CA had an error, ignoring", "error", err, "sshCA", caAuthorizedKey)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -131,13 +131,13 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
for _, rk := range keys {
|
for _, rk := range keys {
|
||||||
kDef, ok := rk.(map[string]any)
|
kDef, ok := rk.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
|
l.Warn("Authorized user had an error, ignoring", "sshKeyConfig", rk)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := kDef["user"].(string)
|
user, ok := kDef["user"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
|
l.Warn("Authorized user is missing the user field", "sshKeyConfig", rk)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,7 +146,11 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
case string:
|
case string:
|
||||||
err := ssh.AddAuthorizedKey(user, v)
|
err := ssh.AddAuthorizedKey(user, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
|
l.Warn("Failed to authorize key",
|
||||||
|
"error", err,
|
||||||
|
"sshKeyConfig", rk,
|
||||||
|
"sshKey", v,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,19 +158,25 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
for _, subK := range v {
|
for _, subK := range v {
|
||||||
sk, ok := subK.(string)
|
sk, ok := subK.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
|
l.Warn("Did not understand ssh key",
|
||||||
|
"sshKeyConfig", rk,
|
||||||
|
"sshKey", subK,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ssh.AddAuthorizedKey(user, sk)
|
err := ssh.AddAuthorizedKey(user, sk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key")
|
l.Warn("Failed to authorize key",
|
||||||
|
"error", err,
|
||||||
|
"sshKeyConfig", sk,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
|
l.Warn("Authorized user is missing the keys field or was not understood", "sshKeyConfig", rk)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -178,7 +188,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
ssh.Stop()
|
ssh.Stop()
|
||||||
runner = func() {
|
runner = func() {
|
||||||
if err := ssh.Run(listen); err != nil {
|
if err := ssh.Run(listen); err != nil {
|
||||||
l.WithField("err", err).Warn("Failed to run the SSH server")
|
l.Warn("Failed to run the SSH server", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -188,7 +198,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
|||||||
return runner, nil
|
return runner, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
|
func attachCommands(l *slog.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
|
||||||
// sandboxDir defaults to a dir in temp. The intention is that end user will
|
// sandboxDir defaults to a dir in temp. The intention is that end user will
|
||||||
// create this dir as needed. Overriding this config value to "" allows
|
// create this dir as needed. Overriding this config value to "" allows
|
||||||
// writing to anywhere in the system.
|
// writing to anywhere in the system.
|
||||||
@@ -789,36 +799,45 @@ func sshGetMutexProfile(sandboxDir string, fs any, a []string, w sshd.StringWrit
|
|||||||
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
func sshLogLevel(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||||
if len(a) == 0 {
|
ctrl, ok := l.Handler().(interface {
|
||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
GetLevel() slog.Level
|
||||||
|
SetLevel(slog.Level)
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return w.WriteLine("Log level is not reconfigurable on this logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
level, err := logrus.ParseLevel(a[0])
|
if len(a) == 0 {
|
||||||
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
|
||||||
|
}
|
||||||
|
|
||||||
|
level, err := logging.ParseLevel(strings.ToLower(a[0]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels))
|
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: trace, debug, info, warn, error", a))
|
||||||
}
|
}
|
||||||
|
|
||||||
l.SetLevel(level)
|
ctrl.SetLevel(level)
|
||||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
return w.WriteLine(fmt.Sprintf("Log level is: %s", logging.LevelName(ctrl.GetLevel())))
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshLogFormat(l *slog.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||||
|
ctrl, ok := l.Handler().(interface {
|
||||||
|
GetFormat() string
|
||||||
|
SetFormat(string) error
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return w.WriteLine("Log format is not reconfigurable on this logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
|
||||||
if len(a) == 0 {
|
if len(a) == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
|
||||||
}
|
}
|
||||||
|
|
||||||
logFormat := strings.ToLower(a[0])
|
if err := ctrl.SetFormat(strings.ToLower(a[0])); err != nil {
|
||||||
switch logFormat {
|
return err
|
||||||
case "text":
|
|
||||||
l.Formatter = &logrus.TextFormatter{}
|
|
||||||
case "json":
|
|
||||||
l.Formatter = &logrus.JSONFormatter{}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
|
||||||
}
|
}
|
||||||
|
return w.WriteLine(fmt.Sprintf("Log format is: %s", ctrl.GetFormat()))
|
||||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||||
|
|||||||
@@ -5,16 +5,16 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/armon/go-radix"
|
"github.com/armon/go-radix"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SSHServer struct {
|
type SSHServer struct {
|
||||||
config *ssh.ServerConfig
|
config *ssh.ServerConfig
|
||||||
l *logrus.Entry
|
l *slog.Logger
|
||||||
|
|
||||||
certChecker *ssh.CertChecker
|
certChecker *ssh.CertChecker
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ type SSHServer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
||||||
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
|
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
s := &SSHServer{
|
s := &SSHServer{
|
||||||
@@ -121,7 +121,7 @@ func (s *SSHServer) AddTrustedCA(pubKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.trustedCAs = append(s.trustedCAs, pk)
|
s.trustedCAs = append(s.trustedCAs, pk)
|
||||||
s.l.WithField("sshKey", pubKey).Info("Trusted CA key")
|
s.l.Info("Trusted CA key", "sshKey", pubKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +139,10 @@ func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tk[string(pk.Marshal())] = true
|
tk[string(pk.Marshal())] = true
|
||||||
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
|
s.l.Info("Authorized ssh key",
|
||||||
|
"sshKey", pubKey,
|
||||||
|
"sshUser", user,
|
||||||
|
)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +159,7 @@ func (s *SSHServer) Run(addr string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.l.WithField("sshListener", addr).Info("SSH server is listening")
|
s.l.Info("SSH server is listening", "sshListener", addr)
|
||||||
|
|
||||||
// Run loops until there is an error
|
// Run loops until there is an error
|
||||||
s.run()
|
s.run()
|
||||||
@@ -172,7 +175,7 @@ func (s *SSHServer) run() {
|
|||||||
c, err := s.listener.Accept()
|
c, err := s.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, net.ErrClosed) {
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
s.l.WithError(err).Warn("Error in listener, shutting down")
|
s.l.Warn("Error in listener, shutting down", "error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -193,23 +196,29 @@ func (s *SSHServer) run() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
|
l := s.l.With(
|
||||||
|
"error", err,
|
||||||
|
"remoteAddress", c.RemoteAddr(),
|
||||||
|
)
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
l = l.WithField("sshUser", conn.User())
|
l = l.With("sshUser", conn.User())
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
if fp != "" {
|
if fp != "" {
|
||||||
l = l.WithField("sshFingerprint", fp)
|
l = l.With("sshFingerprint", fp)
|
||||||
}
|
}
|
||||||
l.Warn("failed to handshake")
|
l.Warn("failed to handshake")
|
||||||
sessionCancel()
|
sessionCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l := s.l.WithField("sshUser", conn.User())
|
l := s.l.With("sshUser", conn.User())
|
||||||
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
|
l.Info("ssh user logged in",
|
||||||
|
"remoteAddress", c.RemoteAddr(),
|
||||||
|
"sshFingerprint", fp,
|
||||||
|
)
|
||||||
|
|
||||||
NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session"))
|
NewSession(s.commands, conn, chans, sessionCancel, l.With("subsystem", "sshd.session"))
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
@@ -221,7 +230,7 @@ func (s *SSHServer) Stop() {
|
|||||||
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
|
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
if err := s.listener.Close(); err != nil {
|
if err := s.listener.Close(); err != nil {
|
||||||
s.l.WithError(err).Warn("Failed to close the sshd listener")
|
s.l.Warn("Failed to close the sshd listener", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,25 +2,25 @@ package sshd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/anmitsu/go-shlex"
|
"github.com/anmitsu/go-shlex"
|
||||||
"github.com/armon/go-radix"
|
"github.com/armon/go-radix"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
type session struct {
|
type session struct {
|
||||||
l *logrus.Entry
|
l *slog.Logger
|
||||||
c *ssh.ServerConn
|
c *ssh.ServerConn
|
||||||
term *term.Terminal
|
term *term.Terminal
|
||||||
commands *radix.Tree
|
commands *radix.Tree
|
||||||
cancel func()
|
cancel func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session {
|
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *slog.Logger) *session {
|
||||||
s := &session{
|
s := &session{
|
||||||
commands: radix.NewFromMap(commands.ToMap()),
|
commands: radix.NewFromMap(commands.ToMap()),
|
||||||
l: l,
|
l: l,
|
||||||
@@ -45,14 +45,14 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
|
|||||||
defer s.Close()
|
defer s.Close()
|
||||||
for newChannel := range chans {
|
for newChannel := range chans {
|
||||||
if newChannel.ChannelType() != "session" {
|
if newChannel.ChannelType() != "session" {
|
||||||
s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
|
s.l.Error("unknown channel type", "sshChannelType", newChannel.ChannelType())
|
||||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, requests, err := newChannel.Accept()
|
channel, requests, err := newChannel.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.l.WithError(err).Warn("could not accept channel")
|
s.l.Warn("could not accept channel", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,12 +95,12 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
|
|||||||
return
|
return
|
||||||
|
|
||||||
default:
|
default:
|
||||||
s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
|
s.l.Debug("Rejected unknown request", "sshRequest", req.Type)
|
||||||
err = req.Reply(false, nil)
|
err = req.Reply(false, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.l.WithError(err).Info("Error handling ssh session requests")
|
s.l.Info("Error handling ssh session requests", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
24
stats.go
24
stats.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -15,14 +16,13 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// startStats initializes stats from config. On success, if any further work
|
// startStats initializes stats from config. On success, if any further work
|
||||||
// is needed to serve stats, it returns a func to handle that work. If no
|
// is needed to serve stats, it returns a func to handle that work. If no
|
||||||
// work is needed, it'll return nil. On failure, it returns nil, error.
|
// work is needed, it'll return nil. On failure, it returns nil, error.
|
||||||
func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
func startStats(l *slog.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||||
mType := c.GetString("stats.type", "")
|
mType := c.GetString("stats.type", "")
|
||||||
if mType == "" || mType == "none" {
|
if mType == "" || mType == "none" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -59,7 +59,7 @@ func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest b
|
|||||||
return startFn, nil
|
return startFn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
|
func startGraphiteStats(l *slog.Logger, i time.Duration, c *config.C, configTest bool) error {
|
||||||
proto := c.GetString("stats.protocol", "tcp")
|
proto := c.GetString("stats.protocol", "tcp")
|
||||||
host := c.GetString("stats.host", "")
|
host := c.GetString("stats.host", "")
|
||||||
if host == "" {
|
if host == "" {
|
||||||
@@ -73,13 +73,17 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !configTest {
|
if !configTest {
|
||||||
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
|
l.Info("Starting graphite",
|
||||||
|
"interval", i,
|
||||||
|
"prefix", prefix,
|
||||||
|
"addr", addr.String(),
|
||||||
|
)
|
||||||
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
|
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
func startPrometheusStats(l *slog.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
|
||||||
namespace := c.GetString("stats.namespace", "")
|
namespace := c.GetString("stats.namespace", "")
|
||||||
subsystem := c.GetString("stats.subsystem", "")
|
subsystem := c.GetString("stats.subsystem", "")
|
||||||
|
|
||||||
@@ -116,9 +120,15 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV
|
|||||||
|
|
||||||
var startFn func()
|
var startFn func()
|
||||||
if !configTest {
|
if !configTest {
|
||||||
|
// promhttp.HandlerOpts.ErrorLog needs a stdlib-shaped Println logger,
|
||||||
|
// so bridge our slog.Logger back to a *log.Logger that emits at Error.
|
||||||
|
errLog := slog.NewLogLogger(l.Handler(), slog.LevelError)
|
||||||
startFn = func() {
|
startFn = func() {
|
||||||
l.Infof("Prometheus stats listening on %s at %s", listen, path)
|
l.Info("Prometheus stats listening",
|
||||||
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
|
"listen", listen,
|
||||||
|
"path", path,
|
||||||
|
)
|
||||||
|
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: errLog}))
|
||||||
log.Fatal(http.ListenAndServe(listen, nil))
|
log.Fatal(http.ListenAndServe(listen, nil))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +1,73 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/slackhq/nebula/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewLogger() *logrus.Logger {
|
// NewLogger returns a *slog.Logger suitable for use in tests. Output goes to
|
||||||
l := logrus.New()
|
// io.Discard by default; set TEST_LOGS=1 (info), 2 (debug), or 3 (trace) to
|
||||||
|
// stream output to stderr for local debugging.
|
||||||
|
func NewLogger() *slog.Logger {
|
||||||
v := os.Getenv("TEST_LOGS")
|
v := os.Getenv("TEST_LOGS")
|
||||||
if v == "" {
|
if v == "" {
|
||||||
l.SetOutput(io.Discard)
|
return slog.New(slog.DiscardHandler)
|
||||||
return l
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
level := slog.LevelInfo
|
||||||
switch v {
|
switch v {
|
||||||
case "2":
|
case "2":
|
||||||
l.SetLevel(logrus.DebugLevel)
|
level = slog.LevelDebug
|
||||||
case "3":
|
case "3":
|
||||||
l.SetLevel(logrus.TraceLevel)
|
level = logging.LevelTrace
|
||||||
default:
|
}
|
||||||
l.SetLevel(logrus.InfoLevel)
|
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
|
||||||
}
|
}
|
||||||
|
|
||||||
return l
|
// NewLoggerWithOutput returns a *slog.Logger whose text output is captured by
|
||||||
|
// w. Timestamps are suppressed so tests can assert on exact output without
|
||||||
|
// baking the current time into expected strings.
|
||||||
|
func NewLoggerWithOutput(w io.Writer) *slog.Logger {
|
||||||
|
return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, nil)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLoggerWithOutputAndLevel is NewLoggerWithOutput with an explicit level
|
||||||
|
// so tests can exercise Enabled-gated paths.
|
||||||
|
func NewLoggerWithOutputAndLevel(w io.Writer, level slog.Level) *slog.Logger {
|
||||||
|
return slog.New(&stripTimeHandler{inner: slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJSONLoggerWithOutput returns a *slog.Logger emitting JSON to w with
|
||||||
|
// timestamps suppressed, for tests that pin the JSON shape.
|
||||||
|
func NewJSONLoggerWithOutput(w io.Writer, level slog.Level) *slog.Logger {
|
||||||
|
return slog.New(&stripTimeHandler{inner: slog.NewJSONHandler(w, &slog.HandlerOptions{Level: level})})
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripTimeHandler zeros each record's time before delegating so slog's
|
||||||
|
// built-in handlers skip emitting the time attribute. Used to avoid
|
||||||
|
// timestamp-dependent assertions in tests without resorting to ReplaceAttr.
|
||||||
|
type stripTimeHandler struct {
|
||||||
|
inner slog.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *stripTimeHandler) Enabled(ctx context.Context, l slog.Level) bool {
|
||||||
|
return h.inner.Enabled(ctx, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *stripTimeHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||||
|
r.Time = time.Time{}
|
||||||
|
return h.inner.Handle(ctx, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *stripTimeHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
return &stripTimeHandler{inner: h.inner.WithAttrs(attrs)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *stripTimeHandler) WithGroup(name string) slog.Handler {
|
||||||
|
return &stripTimeHandler{inner: h.inner.WithGroup(name)}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,11 +9,12 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"log/slog"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
return NewGenericListener(l, ip, port, multi, batch)
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"log/slog"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
return NewGenericListener(l, ip, port, multi, batch)
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
@@ -22,12 +22,12 @@ type StdConn struct {
|
|||||||
*net.UDPConn
|
*net.UDPConn
|
||||||
isV4 bool
|
isV4 bool
|
||||||
sysFd uintptr
|
sysFd uintptr
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Conn = &StdConn{}
|
var _ Conn = &StdConn{}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
lc := NewListenConfig(multi)
|
lc := NewListenConfig(multi)
|
||||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -176,7 +176,7 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.Error("unexpected udp socket receive error", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
@@ -196,7 +196,7 @@ func (u *StdConn) Rebind() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Error("Failed to rebind udp socket")
|
u.l.Error("Failed to rebind udp socket", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -12,22 +12,22 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GenericConn struct {
|
type GenericConn struct {
|
||||||
*net.UDPConn
|
*net.UDPConn
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Conn = &GenericConn{}
|
var _ Conn = &GenericConn{}
|
||||||
|
|
||||||
func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewGenericListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
lc := NewListenConfig(multi)
|
lc := NewListenConfig(multi)
|
||||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -88,7 +88,7 @@ func (u *GenericConn) ListenOut(r EncReader) error {
|
|||||||
// Dampen unexpected message warns to once per minute
|
// Dampen unexpected message warns to once per minute
|
||||||
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
||||||
lastRecvErr = time.Now()
|
lastRecvErr = time.Now()
|
||||||
u.l.WithError(err).Warn("unexpected udp socket receive error")
|
u.l.Warn("unexpected udp socket receive error", "error", err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
@@ -22,7 +22,7 @@ type StdConn struct {
|
|||||||
udpConn *net.UDPConn
|
udpConn *net.UDPConn
|
||||||
rawConn syscall.RawConn
|
rawConn syscall.RawConn
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
batch int
|
batch int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ func setReusePort(network, address string, c syscall.RawConn) error {
|
|||||||
return opErr
|
return opErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
listen := netip.AddrPortFrom(ip, uint16(port))
|
listen := netip.AddrPortFrom(ip, uint16(port))
|
||||||
lc := net.ListenConfig{}
|
lc := net.ListenConfig{}
|
||||||
if multi {
|
if multi {
|
||||||
@@ -242,12 +242,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
s, err := u.GetRecvBuffer()
|
s, err := u.GetRecvBuffer()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
u.l.WithField("size", s).Info("listen.read_buffer was set")
|
u.l.Info("listen.read_buffer was set", "size", s)
|
||||||
} else {
|
} else {
|
||||||
u.l.WithError(err).Warn("Failed to get listen.read_buffer")
|
u.l.Warn("Failed to get listen.read_buffer", "error", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
u.l.WithError(err).Error("Failed to set listen.read_buffer")
|
u.l.Error("Failed to set listen.read_buffer", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,12 +257,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
s, err := u.GetSendBuffer()
|
s, err := u.GetSendBuffer()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
u.l.WithField("size", s).Info("listen.write_buffer was set")
|
u.l.Info("listen.write_buffer was set", "size", s)
|
||||||
} else {
|
} else {
|
||||||
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
|
u.l.Warn("Failed to get listen.write_buffer", "error", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
u.l.WithError(err).Error("Failed to set listen.write_buffer")
|
u.l.Error("Failed to set listen.write_buffer", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,12 +273,12 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
s, err := u.GetSoMark()
|
s, err := u.GetSoMark()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
u.l.WithField("mark", s).Info("listen.so_mark was set")
|
u.l.Info("listen.so_mark was set", "mark", s)
|
||||||
} else {
|
} else {
|
||||||
u.l.WithError(err).Warn("Failed to get listen.so_mark")
|
u.l.Warn("Failed to get listen.so_mark", "error", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.Error("Failed to set listen.so_mark", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,11 +11,12 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"log/slog"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
return NewGenericListener(l, ip, port, multi, batch)
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -17,7 +18,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/conn/winrio"
|
"golang.zx2c4.com/wireguard/conn/winrio"
|
||||||
@@ -53,14 +53,14 @@ type ringBuffer struct {
|
|||||||
|
|
||||||
type RIOConn struct {
|
type RIOConn struct {
|
||||||
isOpen atomic.Bool
|
isOpen atomic.Bool
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
sock windows.Handle
|
sock windows.Handle
|
||||||
rx, tx ringBuffer
|
rx, tx ringBuffer
|
||||||
rq winrio.Rq
|
rq winrio.Rq
|
||||||
results [packetsPerRing]winrio.Result
|
results [packetsPerRing]winrio.Result
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
|
func NewRIOListener(l *slog.Logger, addr netip.Addr, port int) (*RIOConn, error) {
|
||||||
if !winrio.Initialize() {
|
if !winrio.Initialize() {
|
||||||
return nil, errors.New("could not initialize winrio")
|
return nil, errors.New("could not initialize winrio")
|
||||||
}
|
}
|
||||||
@@ -83,7 +83,7 @@ func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, erro
|
|||||||
return u, nil
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
|
||||||
var err error
|
var err error
|
||||||
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -103,7 +103,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// This is a best-effort to prevent errors from being returned by the udp recv operation.
|
// This is a best-effort to prevent errors from being returned by the udp recv operation.
|
||||||
// Quietly log a failure and continue.
|
// Quietly log a failure and continue.
|
||||||
l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl")
|
l.Debug("failed to set UDP_CONNRESET ioctl", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = 0
|
ret = 0
|
||||||
@@ -114,7 +114,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// This is a best-effort to prevent errors from being returned by the udp recv operation.
|
// This is a best-effort to prevent errors from being returned by the udp recv operation.
|
||||||
// Quietly log a failure and continue.
|
// Quietly log a failure and continue.
|
||||||
l.WithError(err).Debug("failed to set UDP_NETRESET ioctl")
|
l.Debug("failed to set UDP_NETRESET ioctl", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = u.rx.Open()
|
err = u.rx.Open()
|
||||||
@@ -156,7 +156,7 @@ func (u *RIOConn) ListenOut(r EncReader) error {
|
|||||||
// Dampen unexpected message warns to once per minute
|
// Dampen unexpected message warns to once per minute
|
||||||
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
||||||
lastRecvErr = time.Now()
|
lastRecvErr = time.Now()
|
||||||
u.l.WithError(err).Warn("unexpected udp socket receive error")
|
u.l.Warn("unexpected udp socket receive error", "error", err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,13 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
@@ -46,10 +47,10 @@ type TesterConn struct {
|
|||||||
done chan struct{}
|
done chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
|
|
||||||
l *logrus.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
|
func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
|
||||||
return &TesterConn{
|
return &TesterConn{
|
||||||
Addr: netip.AddrPortFrom(ip, uint16(port)),
|
Addr: netip.AddrPortFrom(ip, uint16(port)),
|
||||||
RxPackets: make(chan *Packet, 10),
|
RxPackets: make(chan *Packet, 10),
|
||||||
@@ -67,11 +68,12 @@ func (u *TesterConn) Send(packet *Packet) {
|
|||||||
if err := h.Parse(packet.Data); err != nil {
|
if err := h.Parse(packet.Data); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
if u.l.Level >= logrus.DebugLevel {
|
if u.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
u.l.WithField("header", h).
|
u.l.Debug("UDP receiving injected packet",
|
||||||
WithField("udpAddr", packet.From).
|
"header", h,
|
||||||
WithField("dataLen", len(packet.Data)).
|
"udpAddr", packet.From,
|
||||||
Debug("UDP receiving injected packet")
|
"dataLen", len(packet.Data),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-u.done:
|
case <-u.done:
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ package udp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
if multi {
|
if multi {
|
||||||
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
|
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
|
||||||
// The udp stack would need to be reworked to hide away the implementation differences between
|
// The udp stack would need to be reworked to hide away the implementation differences between
|
||||||
@@ -25,7 +24,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return rc, nil
|
return rc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
l.WithError(err).Error("Falling back to standard udp sockets")
|
l.Error("Falling back to standard udp sockets", "error", err)
|
||||||
return NewGenericListener(l, ip, port, multi, batch)
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextualError struct {
|
type ContextualError struct {
|
||||||
@@ -28,12 +28,12 @@ func ContextualizeIfNeeded(msg string, err error) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
|
// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
|
||||||
func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
|
func LogWithContextIfNeeded(msg string, err error, l *slog.Logger) {
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
case *ContextualError:
|
case *ContextualError:
|
||||||
v.Log(l)
|
v.Log(l)
|
||||||
default:
|
default:
|
||||||
l.WithError(err).Error(msg)
|
l.Error(msg, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,10 +51,19 @@ func (ce *ContextualError) Unwrap() error {
|
|||||||
return ce.RealError
|
return ce.RealError
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
// Log emits ce as a single error-level log line with Fields and RealError
|
||||||
|
// promoted to top-level attributes, producing a flat shape callers can grep
|
||||||
|
// or parse without walking into a nested object.
|
||||||
|
func (ce *ContextualError) Log(l *slog.Logger) {
|
||||||
|
attrs := make([]slog.Attr, 0, len(ce.Fields)+1)
|
||||||
|
for k, v := range ce.Fields {
|
||||||
|
attrs = append(attrs, slog.Any(k, v))
|
||||||
|
}
|
||||||
if ce.RealError != nil {
|
if ce.RealError != nil {
|
||||||
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
attrs = append(attrs, slog.Any("error", ce.RealError))
|
||||||
} else {
|
|
||||||
lr.WithFields(ce.Fields).Error(ce.Context)
|
|
||||||
}
|
}
|
||||||
|
// LogAttrs is intentional: attrs is built from a map[string]any so it has
|
||||||
|
// no pair-form equivalent.
|
||||||
|
//nolint:sloglint
|
||||||
|
l.LogAttrs(context.Background(), slog.LevelError, ce.Context, attrs...)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,95 +1,67 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|
||||||
type TestLogWriter struct {
|
|
||||||
Logs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTestLogWriter() *TestLogWriter {
|
|
||||||
return &TestLogWriter{Logs: make([]string, 0)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
|
|
||||||
tl.Logs = append(tl.Logs, string(p))
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tl *TestLogWriter) Reset() {
|
|
||||||
tl.Logs = tl.Logs[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestContextualError_Log(t *testing.T) {
|
func TestContextualError_Log(t *testing.T) {
|
||||||
l := logrus.New()
|
buf := &bytes.Buffer{}
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l := test.NewLoggerWithOutput(buf)
|
||||||
DisableTimestamp: true,
|
|
||||||
DisableColors: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
tl := NewTestLogWriter()
|
|
||||||
l.Out = tl
|
|
||||||
|
|
||||||
// Test a full context line
|
// Test a full context line
|
||||||
tl.Reset()
|
buf.Reset()
|
||||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||||
e.Log(l)
|
e.Log(l)
|
||||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())
|
||||||
|
|
||||||
// Test a line with an error and msg but no fields
|
// Test a line with an error and msg but no fields
|
||||||
tl.Reset()
|
buf.Reset()
|
||||||
e = NewContextualError("test message", nil, errors.New("error"))
|
e = NewContextualError("test message", nil, errors.New("error"))
|
||||||
e.Log(l)
|
e.Log(l)
|
||||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
|
assert.Equal(t, "level=ERROR msg=\"test message\" error=error\n", buf.String())
|
||||||
|
|
||||||
// Test just a context and fields
|
// Test just a context and fields
|
||||||
tl.Reset()
|
buf.Reset()
|
||||||
e = NewContextualError("test message", m{"field": "1"}, nil)
|
e = NewContextualError("test message", m{"field": "1"}, nil)
|
||||||
e.Log(l)
|
e.Log(l)
|
||||||
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
|
assert.Equal(t, "level=ERROR msg=\"test message\" field=1\n", buf.String())
|
||||||
|
|
||||||
// Test just a context
|
// Test just a context
|
||||||
tl.Reset()
|
buf.Reset()
|
||||||
e = NewContextualError("test message", nil, nil)
|
e = NewContextualError("test message", nil, nil)
|
||||||
e.Log(l)
|
e.Log(l)
|
||||||
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
|
assert.Equal(t, "level=ERROR msg=\"test message\"\n", buf.String())
|
||||||
|
|
||||||
// Test just an error
|
// Test just an error
|
||||||
tl.Reset()
|
buf.Reset()
|
||||||
e = NewContextualError("", nil, errors.New("error"))
|
e = NewContextualError("", nil, errors.New("error"))
|
||||||
e.Log(l)
|
e.Log(l)
|
||||||
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
|
assert.Equal(t, "level=ERROR msg=\"\" error=error\n", buf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogWithContextIfNeeded(t *testing.T) {
|
func TestLogWithContextIfNeeded(t *testing.T) {
|
||||||
l := logrus.New()
|
buf := &bytes.Buffer{}
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l := test.NewLoggerWithOutput(buf)
|
||||||
DisableTimestamp: true,
|
|
||||||
DisableColors: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
tl := NewTestLogWriter()
|
|
||||||
l.Out = tl
|
|
||||||
|
|
||||||
// Test ignoring fallback context
|
// Test ignoring fallback context
|
||||||
tl.Reset()
|
buf.Reset()
|
||||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||||
LogWithContextIfNeeded("This should get thrown away", e, l)
|
LogWithContextIfNeeded("This should get thrown away", e, l)
|
||||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
assert.Equal(t, "level=ERROR msg=\"test message\" field=1 error=error\n", buf.String())
|
||||||
|
|
||||||
// Test using fallback context
|
// Test using fallback context
|
||||||
tl.Reset()
|
buf.Reset()
|
||||||
err := fmt.Errorf("this is a normal error")
|
err := fmt.Errorf("this is a normal error")
|
||||||
LogWithContextIfNeeded("Fallback context woo", err, l)
|
LogWithContextIfNeeded("Fallback context woo", err, l)
|
||||||
assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs)
|
assert.Equal(t, "level=ERROR msg=\"Fallback context woo\" error=\"this is a normal error\"\n", buf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextualizeIfNeeded(t *testing.T) {
|
func TestContextualizeIfNeeded(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user