also validate hostinfo locks

This commit is contained in:
Wade Simmons 2023-05-09 11:22:55 -04:00
parent 3e5e48f937
commit 9105eba939
3 changed files with 46 additions and 28 deletions

View File

@ -261,7 +261,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
r := map[uint32]*HostInfo{} r := map[uint32]*HostInfo{}
relays := map[uint32]*HostInfo{} relays := map[uint32]*HostInfo{}
m := HostMap{ m := HostMap{
syncRWMutex: newSyncRWMutex("hostmap", name), syncRWMutex: newSyncRWMutex(mutexKey{Type: "hostmap", SubType: name}),
name: name, name: name,
Indexes: i, Indexes: i,
Relays: relays, Relays: relays,
@ -322,7 +322,7 @@ func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (
if h, ok := hm.Hosts[vpnIp]; !ok { if h, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock() hm.RUnlock()
h = &HostInfo{ h = &HostInfo{
syncRWMutex: newSyncRWMutex("hostinfo"), syncRWMutex: newSyncRWMutex(mutexKey{Type: "hostinfo", ID: uint32(vpnIp)}),
vpnIp: vpnIp, vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{ relayState: RelayState{

View File

@ -9,6 +9,12 @@ import (
type syncRWMutex = sync.RWMutex type syncRWMutex = sync.RWMutex
func newSyncRWMutex(t ...string) syncRWMutex { func newSyncRWMutex(mutexKey) syncRWMutex {
return sync.RWMutex{} return sync.RWMutex{}
} }
type mutexKey struct {
Type string
SubType string
ID uint32
}

View File

@ -4,63 +4,75 @@
package nebula package nebula
import ( import (
"strings" "fmt"
"sync" "sync"
"github.com/timandy/routine" "github.com/timandy/routine"
) )
var threadLocal routine.ThreadLocal = routine.NewThreadLocalWithInitial(func() any { return map[string]bool{} }) var threadLocal routine.ThreadLocal = routine.NewThreadLocalWithInitial(func() any { return map[mutexKey]bool{} })
type mutexKey struct {
Type string
SubType string
ID uint32
}
type syncRWMutex struct { type syncRWMutex struct {
sync.RWMutex sync.RWMutex
mutexType string mutexKey
} }
func newSyncRWMutex(t ...string) syncRWMutex { func newSyncRWMutex(key mutexKey) syncRWMutex {
return syncRWMutex{ return syncRWMutex{
mutexType: strings.Join(t, "-"), mutexKey: key,
} }
} }
func checkMutex(state map[string]bool, add string) { func checkMutex(state map[mutexKey]bool, add mutexKey) {
if add == "hostinfo" { switch add.Type {
if state["hostmap-main"] { case "hostinfo":
panic("grabbing hostinfo lock and already have hostmap-main") // Check for any other hostinfo keys:
} for k, v := range state {
if state["hostmap-pending"] { if k.Type == "hostinfo" && v {
panic("grabbing hostinfo lock and already have hostmap-pending") panic(fmt.Errorf("grabbing hostinfo lock and already have a hostinfo lock: state=%v add=%v", state, add))
} }
} }
if add == "hostmap-pending" { if state[mutexKey{Type: "hostmap", SubType: "main"}] {
if state["hostmap-main"] { panic(fmt.Errorf("grabbing hostinfo lock and already have hostmap-main: state=%v add=%v", state, add))
panic("grabbing hostmap-pending lock and already have hostmap-main") }
if state[mutexKey{Type: "hostmap", SubType: "pending"}] {
panic(fmt.Errorf("grabbing hostinfo lock and already have hostmap-pending: state=%v add=%v", state, add))
}
case "hostmap-pending":
if state[mutexKey{Type: "hostmap", SubType: "main"}] {
panic(fmt.Errorf("grabbing hostmap-pending lock and already have hostmap-main: state=%v add=%v", state, add))
} }
} }
} }
func (s *syncRWMutex) Lock() { func (s *syncRWMutex) Lock() {
m := threadLocal.Get().(map[string]bool) m := threadLocal.Get().(map[mutexKey]bool)
checkMutex(m, s.mutexType) checkMutex(m, s.mutexKey)
m[s.mutexType] = true m[s.mutexKey] = true
s.RWMutex.Lock() s.RWMutex.Lock()
} }
func (s *syncRWMutex) Unlock() { func (s *syncRWMutex) Unlock() {
m := threadLocal.Get().(map[string]bool) m := threadLocal.Get().(map[mutexKey]bool)
m[s.mutexType] = false m[s.mutexKey] = false
s.RWMutex.Unlock() s.RWMutex.Unlock()
} }
func (s *syncRWMutex) RLock() { func (s *syncRWMutex) RLock() {
m := threadLocal.Get().(map[string]bool) m := threadLocal.Get().(map[mutexKey]bool)
checkMutex(m, s.mutexType) checkMutex(m, s.mutexKey)
m[s.mutexType] = true m[s.mutexKey] = true
s.RWMutex.RLock() s.RWMutex.RLock()
} }
func (s *syncRWMutex) RUnlock() { func (s *syncRWMutex) RUnlock() {
m := threadLocal.Get().(map[string]bool) m := threadLocal.Get().(map[mutexKey]bool)
m[s.mutexType] = false m[s.mutexKey] = false
s.RWMutex.RUnlock() s.RWMutex.RUnlock()
} }