mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
Merge remote-tracking branch 'origin/master' into multiport
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||
@@ -15,41 +15,49 @@ type ConntrackCacheTicker struct {
|
||||
cacheV uint64
|
||||
cacheTick atomic.Uint64
|
||||
|
||||
l *slog.Logger
|
||||
cache ConntrackCache
|
||||
}
|
||||
|
||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||
func NewConntrackCacheTicker(ctx context.Context, l *slog.Logger, d time.Duration) *ConntrackCacheTicker {
|
||||
if d == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &ConntrackCacheTicker{
|
||||
l: l,
|
||||
cache: ConntrackCache{},
|
||||
}
|
||||
|
||||
go c.tick(d)
|
||||
go c.tick(ctx, d)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
||||
func (c *ConntrackCacheTicker) tick(ctx context.Context, d time.Duration) {
|
||||
t := time.NewTicker(d)
|
||||
defer t.Stop()
|
||||
for {
|
||||
time.Sleep(d)
|
||||
c.cacheTick.Add(1)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
c.cacheTick.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get checks if the cache ticker has moved to the next version before returning
|
||||
// the map. If it has moved, we reset the map.
|
||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||
c.cacheV = tick
|
||||
if ll := len(c.cache); ll > 0 {
|
||||
if l.Level == logrus.DebugLevel {
|
||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||
if c.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
c.l.Debug("resetting conntrack cache", "len", ll)
|
||||
}
|
||||
c.cache = make(ConntrackCache, ll)
|
||||
}
|
||||
|
||||
69
firewall/cache_test.go
Normal file
69
firewall/cache_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// The tests below pin the log format produced by ConntrackCacheTicker.Get
|
||||
// so changes cannot silently break what operators are grepping for. The
|
||||
// ticker's internal state (cache + cacheTick) is poked directly to avoid
|
||||
// racing a goroutine-driven tick in tests.
|
||||
|
||||
func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheTicker {
|
||||
t.Helper()
|
||||
c := &ConntrackCacheTicker{
|
||||
l: l,
|
||||
cache: make(ConntrackCache, cacheLen),
|
||||
}
|
||||
for i := 0; i < cacheLen; i++ {
|
||||
c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{}
|
||||
}
|
||||
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
|
||||
return c
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||
|
||||
c := newFixedTicker(t, l, 3)
|
||||
c.Get()
|
||||
|
||||
assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String())
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug)
|
||||
|
||||
c := newFixedTicker(t, l, 2)
|
||||
c.Get()
|
||||
|
||||
assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo)
|
||||
|
||||
c := newFixedTicker(t, l, 5)
|
||||
c.Get()
|
||||
|
||||
assert.Empty(t, buf.String())
|
||||
}
|
||||
|
||||
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
|
||||
|
||||
c := newFixedTicker(t, l, 0)
|
||||
c.Get()
|
||||
|
||||
assert.Empty(t, buf.String())
|
||||
}
|
||||
@@ -23,7 +23,10 @@ const (
|
||||
type Packet struct {
|
||||
LocalAddr netip.Addr
|
||||
RemoteAddr netip.Addr
|
||||
LocalPort uint16
|
||||
// LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP.
|
||||
LocalPort uint16
|
||||
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
|
||||
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
|
||||
RemotePort uint16
|
||||
Protocol uint8
|
||||
Fragment bool
|
||||
@@ -47,6 +50,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
|
||||
proto = "tcp"
|
||||
case ProtoICMP:
|
||||
proto = "icmp"
|
||||
case ProtoICMPv6:
|
||||
proto = "icmpv6"
|
||||
case ProtoUDP:
|
||||
proto = "udp"
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user