mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 00:44:25 +01:00
add locking for stop crash
This commit is contained in:
10
firewall.go
10
firewall.go
@@ -423,7 +423,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
|||||||
|
|
||||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
// returns nil if the packet should not be dropped.
|
// returns nil if the packet should not be dropped.
|
||||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error {
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// Check if we spoke to this tuple, if we did then allow this packet
|
||||||
if f.inConns(fp, h, caPool, localCache) {
|
if f.inConns(fp, h, caPool, localCache) {
|
||||||
return nil
|
return nil
|
||||||
@@ -490,12 +490,10 @@ func (f *Firewall) EmitStats() {
|
|||||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool {
|
||||||
if localCache != nil {
|
if localCache != nil && localCache.Has(fp) {
|
||||||
if _, ok := localCache[fp]; ok {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
conntrack.Lock()
|
conntrack.Lock()
|
||||||
|
|
||||||
@@ -559,7 +557,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
|
|
||||||
if localCache != nil {
|
if localCache != nil {
|
||||||
localCache[fp] = struct{}{}
|
localCache.Add(fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -9,13 +10,58 @@ import (
|
|||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
// has been seen in the conntrack table.
|
// has been seen in the conntrack table.
|
||||||
type ConntrackCache map[Packet]struct{}
|
type ConntrackCache struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
entries map[Packet]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConntrackCache() *ConntrackCache {
|
||||||
|
return &ConntrackCache{entries: make(map[Packet]struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Has(p Packet) bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
_, ok := c.entries[p]
|
||||||
|
c.mu.Unlock()
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Add(p Packet) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.entries[p] = struct{}{}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Len() int {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
l := len(c.entries)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConntrackCache) Reset(capHint int) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.entries = make(map[Packet]struct{}, capHint)
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
type ConntrackCacheTicker struct {
|
type ConntrackCacheTicker struct {
|
||||||
cacheV uint64
|
cacheV uint64
|
||||||
cacheTick atomic.Uint64
|
cacheTick atomic.Uint64
|
||||||
|
|
||||||
cache ConntrackCache
|
cache *ConntrackCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||||
@@ -23,9 +69,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ConntrackCacheTicker{
|
c := &ConntrackCacheTicker{cache: newConntrackCache()}
|
||||||
cache: ConntrackCache{},
|
|
||||||
}
|
|
||||||
|
|
||||||
go c.tick(d)
|
go c.tick(d)
|
||||||
|
|
||||||
@@ -41,17 +85,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
|||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
// the map. If it has moved, we reset the map.
|
// the map. If it has moved, we reset the map.
|
||||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||||
c.cacheV = tick
|
c.cacheV = tick
|
||||||
if ll := len(c.cache); ll > 0 {
|
if ll := c.cache.Len(); ll > 0 {
|
||||||
if l.Level == logrus.DebugLevel {
|
if l.Level == logrus.DebugLevel {
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||||
}
|
}
|
||||||
c.cache = make(ConntrackCache, ll)
|
c.cache.Reset(ll)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ const (
|
|||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) {
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
@@ -466,7 +466,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
|
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
pkt *overlay.Packet
|
pkt *overlay.Packet
|
||||||
|
|||||||
Reference in New Issue
Block a user