diff --git a/connection_manager.go b/connection_manager.go index a189756..924d198 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -3,7 +3,6 @@ package nebula import ( "bytes" "context" - "sync" "time" "github.com/rcrowley/go-metrics" @@ -27,14 +26,14 @@ const ( type connectionManager struct { in map[uint32]struct{} - inLock *sync.RWMutex + inLock syncRWMutex out map[uint32]struct{} - outLock *sync.RWMutex + outLock syncRWMutex // relayUsed holds which relay localIndexs are in use relayUsed map[uint32]struct{} - relayUsedLock *sync.RWMutex + relayUsedLock syncRWMutex hostMap *HostMap trafficTimer *LockingTimerWheel[uint32] @@ -59,11 +58,11 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface nc := &connectionManager{ hostMap: intf.hostMap, in: make(map[uint32]struct{}), - inLock: &sync.RWMutex{}, + inLock: newSyncRWMutex(mutexKey{Type: mutexKeyTypeConnectionManagerIn}), out: make(map[uint32]struct{}), - outLock: &sync.RWMutex{}, + outLock: newSyncRWMutex(mutexKey{Type: mutexKeyTypeConnectionManagerOut}), relayUsed: make(map[uint32]struct{}), - relayUsedLock: &sync.RWMutex{}, + relayUsedLock: newSyncRWMutex(mutexKey{Type: mutexKeyTypeConnectionManagerRelayUsed}), trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max), intf: intf, pendingDeletion: make(map[uint32]struct{}), diff --git a/connection_state.go b/connection_state.go index 8ef8b3a..31d2102 100644 --- a/connection_state.go +++ b/connection_state.go @@ -3,7 +3,6 @@ package nebula import ( "crypto/rand" "encoding/json" - "sync" "sync/atomic" "github.com/flynn/noise" @@ -23,7 +22,7 @@ type ConnectionState struct { initiator bool messageCounter atomic.Uint64 window *Bits - writeLock sync.Mutex + writeLock syncMutex } func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { @@ -71,6 +70,7 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i initiator: initiator, window: b, myCert: certState.Certificate, + writeLock: newSyncMutex(mutexKey{Type: mutexKeyTypeConnectionStateWrite}), } return ci diff --git a/firewall.go b/firewall.go index 64fada3..513aaf5 100644 --- a/firewall.go +++ b/firewall.go @@ -11,7 +11,6 @@ import ( "reflect" "strconv" "strings" - "sync" "time" "github.com/rcrowley/go-metrics" @@ -78,7 +77,7 @@ type firewallMetrics struct { } type FirewallConntrack struct { - sync.Mutex + syncMutex Conns map[firewall.Packet]*conn TimerWheel *TimerWheel[firewall.Packet] @@ -149,6 +148,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D return &Firewall{ Conntrack: &FirewallConntrack{ + syncMutex: newSyncMutex(mutexKey{Type: mutexKeyTypeFirewallConntrack}), Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](min, max), }, diff --git a/go.mod b/go.mod index f84a9fc..da13f10 100644 --- a/go.mod +++ b/go.mod @@ -39,8 +39,11 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/emirpasic/gods v1.18.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/btree v1.0.1 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/heimdalr/dag v1.4.0 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect diff --git a/go.sum b/go.sum index 32c9d12..876e354 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go. github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ= github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= @@ -31,6 +33,7 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -58,6 +61,10 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/heimdalr/dag v1.4.0 h1:zG3JA4RDVLc55k3AXAgfwa+EgBNZ0TkfOO3C29Ucpmg= +github.com/heimdalr/dag v1.4.0/go.mod h1:OCh6ghKmU0hPjtwMqWBoNxPmtRioKd1xSu7Zs4sbIqM= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= diff --git a/handshake_manager.go b/handshake_manager.go index e49b9ba..62204b3 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -7,7 +7,6 @@ import ( "encoding/binary" "errors" "net" - "sync" "time" "github.com/rcrowley/go-metrics" @@ -65,7 +64,7 @@ type HandshakeManager struct { } type HandshakeHostInfo struct { - sync.Mutex + syncMutex startTime time.Time // Time that we first started trying with this handshake ready bool // Is the handshake ready @@ -397,6 +396,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han } hh := &HandshakeHostInfo{ + syncMutex: newSyncMutex(mutexKey{Type: mutexKeyTypeHandshakeHostInfo, ID: uint32(vpnIp)}), hostinfo: hostinfo, startTime: time.Now(), } diff --git a/lighthouse.go b/lighthouse.go index 2193ad3..7fc573a 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "net/netip" - "sync" "sync/atomic" "time" @@ -33,7 +32,7 @@ type netIpAndPort struct { type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time - sync.RWMutex //Because we concurrently read and write to our maps + syncRWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool myVpnIp iputil.VpnIp @@ -101,6 +100,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, ones, _ := myVpnNet.Mask.Size() h := LightHouse{ + syncRWMutex: newSyncRWMutex(mutexKey{Type: mutexKeyTypeLightHouse}), ctx: ctx, amLighthouse: amLighthouse, myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), diff --git a/mutex.go b/mutex.go index 1bb3c20..bcf6d88 100644 --- a/mutex.go +++ b/mutex.go @@ -1,27 +1,59 @@ -//go:build !mutex_debug -// +build !mutex_debug - package nebula -import ( - "sync" -) - -type syncRWMutex = sync.RWMutex +import "fmt" type mutexKeyType string const ( - mutexKeyTypeHostMap mutexKeyType = "hostmap" - mutexKeyTypeHostInfo = "hostinfo" - mutexKeyTypeHandshakeManager = "handshake-manager" + mutexKeyTypeHostMap mutexKeyType = "hostmap" + + mutexKeyTypeLightHouse = "lighthouse" + mutexKeyTypeFirewallConntrack = "firewall-conntrack" + mutexKeyTypeHostInfo = "hostinfo" + mutexKeyTypeHandshakeHostInfo = "handshake-hostinfo" + mutexKeyTypeHandshakeManager = "handshake-manager" + mutexKeyTypeConnectionStateWrite = "connection-state-write-lock" + + mutexKeyTypeConnectionManagerIn = "connection-manager-in-lock" + mutexKeyTypeConnectionManagerOut = "connection-manager-out-lock" + mutexKeyTypeConnectionManagerRelayUsed = "connection-manager-relay-used-lock" ) -func newSyncRWMutex(mutexKey) syncRWMutex { - return sync.RWMutex{} +// For each Key in this map, the Value is a list of lock types you can already have +// when you want to grab that Key. This ensures that locks are always fetched +// in the same order, to prevent deadlocks. +var allowedConcurrentLocks = map[mutexKeyType][]mutexKeyType{ + mutexKeyTypeHostMap: {mutexKeyTypeHandshakeHostInfo}, + mutexKeyTypeFirewallConntrack: {mutexKeyTypeHandshakeHostInfo}, + + mutexKeyTypeHandshakeManager: {mutexKeyTypeHostMap}, + mutexKeyTypeConnectionStateWrite: {mutexKeyTypeHostMap}, + + mutexKeyTypeLightHouse: {mutexKeyTypeHandshakeManager}, + + mutexKeyTypeConnectionManagerIn: {mutexKeyTypeHostMap}, + mutexKeyTypeConnectionManagerOut: {mutexKeyTypeConnectionStateWrite, mutexKeyTypeConnectionManagerIn}, + mutexKeyTypeConnectionManagerRelayUsed: {mutexKeyTypeHandshakeHostInfo}, } type mutexKey struct { Type mutexKeyType ID uint32 } + +type mutexValue struct { + file string + line int +} + +func (m mutexKey) String() string { + if m.ID == 0 { + return fmt.Sprintf("%s", m.Type) + } else { + return fmt.Sprintf("%s(%d)", m.Type, m.ID) + } +} + +func (m mutexValue) String() string { + return fmt.Sprintf("%s:%d", m.file, m.line) +} diff --git a/mutex_debug.go b/mutex_debug.go index 605964b..ce52590 100644 --- a/mutex_debug.go +++ b/mutex_debug.go @@ -8,34 +8,42 @@ import ( "runtime" "sync" + "github.com/heimdalr/dag" "github.com/timandy/routine" ) var threadLocal routine.ThreadLocal = routine.NewThreadLocalWithInitial(func() any { return map[mutexKey]mutexValue{} }) -type mutexKeyType string +var allowedDAG *dag.DAG -const ( - mutexKeyTypeHostMap mutexKeyType = "hostmap" - mutexKeyTypeHostInfo = "hostinfo" - mutexKeyTypeHandshakeManager = "handshake-manager" -) +func init() { + allowedDAG = dag.NewDAG() + for k, v := range allowedConcurrentLocks { + allowedDAG.AddVertexByID(string(k), k) + for _, t := range v { + if _, err := allowedDAG.GetVertex(string(t)); err != nil { + allowedDAG.AddVertexByID(string(t), t) + } + } + } + for k, v := range allowedConcurrentLocks { + for _, t := range v { + allowedDAG.AddEdge(string(t), string(k)) + } + } -// For each Key in this map, the Value is a list of lock types you can already have -// when you want to grab that Key. This ensures that locks are always fetched -// in the same order, to prevent deadlocks. -var allowedConcurrentLocks = map[mutexKeyType][]mutexKeyType{ - mutexKeyTypeHandshakeManager: {mutexKeyTypeHostMap}, -} + for k := range allowedConcurrentLocks { + anc, err := allowedDAG.GetAncestors(string(k)) + if err != nil { + panic(err) + } -type mutexKey struct { - Type mutexKeyType - ID uint32 -} - -type mutexValue struct { - file string - line int + var allowed []mutexKeyType + for t := range anc { + allowed = append(allowed, mutexKeyType(t)) + } + allowedConcurrentLocks[k] = allowed + } } type syncRWMutex struct { @@ -43,12 +51,23 @@ type syncRWMutex struct { mutexKey } +type syncMutex struct { + sync.Mutex + mutexKey +} + func newSyncRWMutex(key mutexKey) syncRWMutex { return syncRWMutex{ mutexKey: key, } } +func newSyncMutex(key mutexKey) syncMutex { + return syncMutex{ + mutexKey: key, + } +} + func alertMutex(err error) { panic(err) // NOTE: you could switch to this log Line and remove the panic if you want @@ -108,14 +127,17 @@ func (s *syncRWMutex) RUnlock() { s.RWMutex.RUnlock() } -func (m mutexKey) String() string { - if m.ID == 0 { - return fmt.Sprintf("%s", m.Type) - } else { - return fmt.Sprintf("%s(%d)", m.Type, m.ID) - } +func (s *syncMutex) Lock() { + m := threadLocal.Get().(map[mutexKey]mutexValue) + checkMutex(m, s.mutexKey) + v := mutexValue{} + _, v.file, v.line, _ = runtime.Caller(1) + m[s.mutexKey] = v + s.Mutex.Lock() } -func (m mutexValue) String() string { - return fmt.Sprintf("%s:%d", m.file, m.line) +func (s *syncMutex) Unlock() { + m := threadLocal.Get().(map[mutexKey]mutexValue) + delete(m, s.mutexKey) + s.Mutex.Unlock() } diff --git a/mutex_nodebug.go b/mutex_nodebug.go new file mode 100644 index 0000000..823c4a7 --- /dev/null +++ b/mutex_nodebug.go @@ -0,0 +1,19 @@ +//go:build !mutex_debug +// +build !mutex_debug + +package nebula + +import ( + "sync" +) + +type syncRWMutex = sync.RWMutex +type syncMutex = sync.Mutex + +func newSyncRWMutex(mutexKey) syncRWMutex { + return sync.RWMutex{} +} + +func newSyncMutex(mutexKey) syncMutex { + return sync.Mutex{} +}