add locking for stop crash

This commit is contained in:
Ryan
2025-11-05 11:58:25 -05:00
parent c8980d34cf
commit 2d128a3254
4 changed files with 60 additions and 18 deletions

View File

@@ -1,6 +1,7 @@
package firewall
import (
"sync"
"sync/atomic"
"time"
@@ -9,13 +10,58 @@ import (
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[Packet]struct{}
type ConntrackCache struct {
mu sync.Mutex
entries map[Packet]struct{}
}
func newConntrackCache() *ConntrackCache {
return &ConntrackCache{entries: make(map[Packet]struct{})}
}
func (c *ConntrackCache) Has(p Packet) bool {
if c == nil {
return false
}
c.mu.Lock()
_, ok := c.entries[p]
c.mu.Unlock()
return ok
}
func (c *ConntrackCache) Add(p Packet) {
if c == nil {
return
}
c.mu.Lock()
c.entries[p] = struct{}{}
c.mu.Unlock()
}
func (c *ConntrackCache) Len() int {
if c == nil {
return 0
}
c.mu.Lock()
l := len(c.entries)
c.mu.Unlock()
return l
}
func (c *ConntrackCache) Reset(capHint int) {
if c == nil {
return
}
c.mu.Lock()
c.entries = make(map[Packet]struct{}, capHint)
c.mu.Unlock()
}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick atomic.Uint64
cache ConntrackCache
cache *ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
@@ -23,9 +69,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
c := &ConntrackCacheTicker{cache: newConntrackCache()}
go c.tick(d)
@@ -41,17 +85,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
if c == nil {
return nil
}
if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if ll := c.cache.Len(); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
c.cache.Reset(ll)
}
}