diff --git a/connection_manager_test.go b/connection_manager_test.go index ecd2880..3f0dc40 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse { addrMap: map[netip.Addr]*RemoteList{}, queryChan: make(chan netip.Addr, 10), } - lighthouses := map[netip.Addr]struct{}{} + lighthouses := []netip.Addr{} staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) diff --git a/lighthouse.go b/lighthouse.go index 1b67e3a..809b04e 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -57,7 +57,7 @@ type LightHouse struct { // staticList exists to avoid having a bool in each addrMap entry // since static should be rare staticList atomic.Pointer[map[netip.Addr]struct{}] - lighthouses atomic.Pointer[map[netip.Addr]struct{}] + lighthouses atomic.Pointer[[]netip.Addr] interval atomic.Int64 updateCancel context.CancelFunc @@ -108,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } - lighthouses := make(map[netip.Addr]struct{}) + lighthouses := make([]netip.Addr, 0) h.lighthouses.Store(&lighthouses) staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) @@ -144,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} { return *lh.staticList.Load() } -func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { +func (lh *LightHouse) GetLighthouses() []netip.Addr { return *lh.lighthouses.Load() } @@ -307,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.hosts") { - lhMap := make(map[netip.Addr]struct{}) - err := lh.parseLighthouses(c, lhMap) + lhList, err := lh.parseLighthouses(c) if err != nil { return err } - lh.lighthouses.Store(&lhMap) + lh.lighthouses.Store(&lhList) if !initial { //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic lh.l.Info("lighthouse.hosts has changed") @@ -347,36 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return nil } -func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { +func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) { lhs := c.GetStringSlice("lighthouse.hosts", []string{}) if lh.amLighthouse && len(lhs) != 0 { lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } + out := make([]netip.Addr, len(lhs)) for i, host := range lhs { addr, err := netip.ParseAddr(host) if err != nil { - return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) + return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } if !lh.myVpnNetworksTable.Contains(addr) { - return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) + return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) } - lhMap[addr] = struct{}{} + out[i] = addr } - if !lh.amLighthouse && len(lhMap) == 0 { + if !lh.amLighthouse && len(out) == 0 { lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries") } staticList := lh.GetStaticHostList() - for lhAddr, _ := range lhMap { - if _, ok := staticList[lhAddr]; !ok { - return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) + for i := range out { + if _, ok := staticList[out[i]]; !ok { + return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i]) } } - return nil + return out, nil } func getStaticMapCadence(c *config.C) (time.Duration, error) { @@ -711,15 +711,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo } func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { - _, ok := lh.GetLighthouses()[vpnAddr] - return ok + l := lh.GetLighthouses() + for i := range l { + if l[i] == vpnAddr { + return true + } + } + return false } -func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool { +func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool { l := lh.GetLighthouses() - for _, a := range vpnAddr { - if _, ok := l[a]; ok { - return true + for i := range vpnAddrs { + for j := range l { + if l[j] == vpnAddrs[i] { + return true + } } } return false @@ -761,7 +768,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { queried := 0 lighthouses := lh.GetLighthouses() - for lhVpnAddr := range lighthouses { + for _, lhVpnAddr := range lighthouses { hi := lh.ifce.GetHostInfo(lhVpnAddr) if hi != nil { v = hi.ConnectionState.myCert.Version() @@ -879,7 +886,7 @@ func (lh *LightHouse) SendUpdate() { updated := 0 lighthouses := lh.GetLighthouses() - for lhVpnAddr := range lighthouses { + for _, lhVpnAddr := range lighthouses { var v cert.Version hi := lh.ifce.GetHostInfo(lhVpnAddr) if hi != nil { @@ -1289,7 +1296,6 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { //It's possible the lighthouse is communicating with us using a non primary vpn addr, //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. - //maybe one day we'll have a better idea, if it matters. if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return }