diff --git a/firewall.go b/firewall.go index 971c156..9f2cee5 100644 --- a/firewall.go +++ b/firewall.go @@ -423,7 +423,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet if f.inConns(fp, h, caPool, localCache) { return nil @@ -490,11 +490,9 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { - if localCache != nil { - if _, ok := localCache[fp]; ok { - return true - } +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool { + if localCache != nil && localCache.Has(fp) { + return true } conntrack := f.Conntrack conntrack.Lock() @@ -559,7 +557,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, conntrack.Unlock() if localCache != nil { - localCache[fp] = struct{}{} + localCache.Add(fp) } return true diff --git a/firewall/cache.go b/firewall/cache.go index 71b83f4..6bbe509 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -1,6 +1,7 @@ package firewall import ( + "sync" "sync/atomic" "time" @@ -9,13 +10,58 @@ import ( // ConntrackCache is used as a local routine cache to know if a given flow // has been seen in the conntrack table. -type ConntrackCache map[Packet]struct{} +type ConntrackCache struct { + mu sync.Mutex + entries map[Packet]struct{} +} + +func newConntrackCache() *ConntrackCache { + return &ConntrackCache{entries: make(map[Packet]struct{})} +} + +func (c *ConntrackCache) Has(p Packet) bool { + if c == nil { + return false + } + c.mu.Lock() + _, ok := c.entries[p] + c.mu.Unlock() + return ok +} + +func (c *ConntrackCache) Add(p Packet) { + if c == nil { + return + } + c.mu.Lock() + c.entries[p] = struct{}{} + c.mu.Unlock() +} + +func (c *ConntrackCache) Len() int { + if c == nil { + return 0 + } + c.mu.Lock() + l := len(c.entries) + c.mu.Unlock() + return l +} + +func (c *ConntrackCache) Reset(capHint int) { + if c == nil { + return + } + c.mu.Lock() + c.entries = make(map[Packet]struct{}, capHint) + c.mu.Unlock() +} type ConntrackCacheTicker struct { cacheV uint64 cacheTick atomic.Uint64 - cache ConntrackCache + cache *ConntrackCache } func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { @@ -23,9 +69,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { return nil } - c := &ConntrackCacheTicker{ - cache: ConntrackCache{}, - } + c := &ConntrackCacheTicker{cache: newConntrackCache()} go c.tick(d) @@ -41,17 +85,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) { // Get checks if the cache ticker has moved to the next version before returning // the map. If it has moved, we reset the map. -func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { +func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache { if c == nil { return nil } if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick - if ll := len(c.cache); ll > 0 { + if ll := c.cache.Len(); ll > 0 { if l.Level == logrus.DebugLevel { l.WithField("len", ll).Debug("resetting conntrack cache") } - c.cache = make(ConntrackCache, ll) + c.cache.Reset(ll) } } diff --git a/inside.go b/inside.go index 9e3672c..66b136b 100644 --- a/inside.go +++ b/inside.go @@ -13,7 +13,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { diff --git a/outside.go b/outside.go index 3eeaa4f..29d5134 100644 --- a/outside.go +++ b/outside.go @@ -20,7 +20,7 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors @@ -466,7 +466,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool { var ( err error pkt *overlay.Packet