diff --git a/control_test.go b/control_test.go index fbf29c0..66a118a 100644 --- a/control_test.go +++ b/control_test.go @@ -66,7 +66,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { localIndexId: 201, vpnIp: vpnIp, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, @@ -85,7 +85,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { localIndexId: 201, vpnIp: vpnIp2, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, diff --git a/handshake_ix.go b/handshake_ix.go index 150e129..0e9c62a 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -151,7 +151,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, diff --git a/handshake_manager.go b/handshake_manager.go index d87ff02..56472dd 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -403,7 +403,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, + relays: nil, relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, diff --git a/hostmap.go b/hostmap.go index f7da0ad..0b34de4 100644 --- a/hostmap.go +++ b/hostmap.go @@ -4,6 +4,7 @@ import ( "errors" "net" "net/netip" + "slices" "sync" "sync/atomic" "time" @@ -69,15 +70,20 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays []netip.Addr // Ordered set of VpnIp's of Hosts to use as relays to access this peer + relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() - delete(rs.relays, ip) + for idx, val := range rs.relays { + if val == ip { + rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...) + return + } + } } func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) { @@ -122,16 +128,16 @@ func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() - rs.relays[ip] = struct{}{} + if !slices.Contains(rs.relays, ip) { + rs.relays = append(rs.relays, ip) + } } func (rs *RelayState) CopyRelayIps() []netip.Addr { + ret := make([]netip.Addr, len(rs.relays)) rs.RLock() defer rs.RUnlock() - ret := make([]netip.Addr, 0, len(rs.relays)) - for ip := range rs.relays { - ret = append(ret, ip) - } + copy(ret, rs.relays) return ret } diff --git a/hostmap_test.go b/hostmap_test.go index 7e2feb8..6eb8751 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHostMap_MakePrimary(t *testing.T) { @@ -225,3 +226,31 @@ func TestHostMap_reload(t *testing.T) { c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) } + +func TestHostMap_RelayState(t *testing.T) { + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + a1 := netip.MustParseAddr("::1") + a2 := netip.MustParseAddr("2001::1") + + h1.relayState.InsertRelayTo(a1) + assert.Equal(t, h1.relayState.relays, []netip.Addr{a1}) + h1.relayState.InsertRelayTo(a2) + assert.Equal(t, h1.relayState.relays, []netip.Addr{a1, a2}) + // Ensure that the first relay added is the first one returned in the copy + currentRelays := h1.relayState.CopyRelayIps() + require.Len(t, currentRelays, 2) + assert.Equal(t, currentRelays[0], a1) + + // Deleting the last one in the list works ok + h1.relayState.DeleteRelay(a2) + assert.Equal(t, h1.relayState.relays, []netip.Addr{a1}) + + // Deleting an element not in the list works ok + h1.relayState.DeleteRelay(a2) + assert.Equal(t, h1.relayState.relays, []netip.Addr{a1}) + + // Deleting the only element in the list works ok + h1.relayState.DeleteRelay(a1) + assert.Equal(t, h1.relayState.relays, []netip.Addr{}) + +}