Remove handshake race avoidance (#820)

Co-authored-by: Wade Simmons <wadey@slack-corp.com>
This commit is contained in:
Nate Brown 2023-03-13 12:35:14 -05:00 committed by GitHub
parent 2ea360e5e2
commit 92cc32f844
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 742 additions and 158 deletions

View File

@ -181,6 +181,14 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
continue continue
} }
// Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to
// decide if this should be the main hostinfo if we are seeing traffic on it
primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
}
// If we saw an incoming packets from this ip and peer's certificate is not // If we saw an incoming packets from this ip and peer's certificate is not
// expired, just ignore. // expired, just ignore.
if traf { if traf {
@ -191,6 +199,20 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
} }
n.ClearLocalIndex(localIndex) n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex) n.ClearPendingDeletion(localIndex)
if !mainHostInfo {
if hostinfo.vpnIp > n.intf.myVpnIp {
// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
// This the primary and prime the old primary hostinfo for testing
n.hostMap.MakePrimary(hostinfo)
n.Out(primary.localIndexId)
} else {
// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
// Keep tracking so that we can tear it down when it goes away
n.Out(hostinfo.localIndexId)
}
}
continue continue
} }
@ -198,7 +220,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
WithField("tunnelCheck", m{"state": "testing", "method": "active"}). WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status") Debug("Tunnel status")
if hostinfo != nil && hostinfo.ConnectionState != nil { if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out) n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)

View File

@ -80,7 +80,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
certState: cs, certState: cs,
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.addHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo.localIndexId)
@ -156,7 +156,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
certState: cs, certState: cs,
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.addHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo.localIndexId)

View File

@ -95,12 +95,21 @@ func (c *Control) RebindUDPServer() {
c.f.rebindCount++ c.f.rebindCount++
} }
// ListHostmap returns details about the actual or pending (handshaking) hostmap // ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo { func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
if pendingMap { if pendingMap {
return listHostMap(c.f.handshakeManager.pendingHostMap) return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
} else { } else {
return listHostMap(c.f.hostMap) return listHostMapHosts(c.f.hostMap)
}
}
// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
if pendingMap {
return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
} else {
return listHostMapIndexes(c.f.hostMap)
} }
} }
@ -232,7 +241,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
return chi return chi
} }
func listHostMap(hm *HostMap) []ControlHostInfo { func listHostMapHosts(hm *HostMap) []ControlHostInfo {
hm.RLock() hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts)) hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0 i := 0
@ -244,3 +253,16 @@ func listHostMap(hm *HostMap) []ControlHostInfo {
return hosts return hosts
} }
func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Indexes))
i := 0
for _, v := range hm.Indexes {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
return hosts
}

View File

@ -19,10 +19,10 @@ import (
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@ -32,7 +32,7 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs() r.CancelFlowLogs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl) _ = r.RouteForAllUntilTxTun(theirControl)
} }
@ -42,18 +42,18 @@ func BenchmarkHotPath(b *testing.B) {
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
theirControl.Start() theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now") t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@ -74,16 +74,16 @@ func TestGoodHandshake(t *testing.T) {
myControl.WaitForType(1, 0, theirControl) myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct") t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right") t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
t.Log("Do a bidirectional tunnel test") t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow() defer r.RenderFlow()
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop() myControl.Stop()
@ -97,15 +97,15 @@ func TestWrongResponderHandshake(t *testing.T) {
// The IPs here are chosen on purpose: // The IPs here are chosen on purpose:
// The current remote handling will sort by preference, public, and then lexically. // The current remote handling will sort by preference, public, and then lexically.
// So we need them to have a higher address than evil (we could apply a preference though) // So we need them to have a higher address than evil (we could apply a preference though)
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
// Add their real udp addr, which should be tried after evil. // Add their real udp addr, which should be tried after evil.
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl) r := router.NewR(t, myControl, theirControl, evilControl)
@ -117,7 +117,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start() evilControl.Start()
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{} h := &header.H{}
err := h.Parse(p.Data) err := h.Parse(p.Data)
@ -136,18 +136,18 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("My cached packet should be received by them") t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
t.Log("Test the tunnel with them") t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
t.Log("Flush all packets from all controllers") t.Log("Flush all packets from all controllers")
r.FlushAll() r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone //TODO: assert hostmaps for everyone
@ -157,14 +157,17 @@ func TestWrongResponderHandshake(t *testing.T) {
theirControl.Stop() theirControl.Stop()
} }
func Test_Case1_Stage1Race(t *testing.T) { func TestStage1Race(t *testing.T) {
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
// But will eventually collapse down to a single tunnel
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
// Put their info in our lighthouse and vice versa // Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@ -175,8 +178,8 @@ func Test_Case1_Stage1Race(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake to start on both me and them") t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets") t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true) myHsForThem := myControl.GetFromUDP(true)
@ -185,44 +188,165 @@ func Test_Case1_Stage1Race(t *testing.T) {
r.Log("Now inject both stage 1 handshake packets") r.Log("Now inject both stage 1 handshake packets")
r.InjectUDPPacket(theirControl, myControl, theirHsForMe) r.InjectUDPPacket(theirControl, myControl, theirHsForMe)
r.InjectUDPPacket(myControl, theirControl, myHsForThem) r.InjectUDPPacket(myControl, theirControl, myHsForThem)
//TODO: they should win, grab their index for me and make sure I use it in the end.
r.Log("They should not have a stage 2 (won the race) but I should send one") r.Log("Route until they receive a message packet")
r.InjectUDPPacket(myControl, theirControl, myControl.GetFromUDP(true)) myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
r.Log("Route for me until I send a message packet to them") r.Log("Their cached packet should be received by me")
r.RouteForAllUntilAfterMsgTypeTo(theirControl, header.Message, header.MessageNone) theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
t.Log("My cached packet should be received by them") r.Log("Do a bidirectional tunnel test")
myCachedPacket := theirControl.GetFromTun(true) assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
t.Log("Route for them until I send a message packet to me") myHostmapHosts := myControl.ListHostmapHosts(false)
theirControl.WaitForType(1, 0, myControl) myHostmapIndexes := myControl.ListHostmapIndexes(false)
theirHostmapHosts := theirControl.ListHostmapHosts(false)
theirHostmapIndexes := theirControl.ListHostmapIndexes(false)
t.Log("Their cached packet should be received by me") // We should have two tunnels on both sides
theirCachedPacket := myControl.GetFromTun(true) assert.Len(t, myHostmapHosts, 1)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80) assert.Len(t, theirHostmapHosts, 1)
assert.Len(t, myHostmapIndexes, 2)
assert.Len(t, theirHostmapIndexes, 2)
t.Log("Do a bidirectional tunnel test") r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
r.Log("Spin until connection manager tears down a tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// We should only have a single tunnel now on both sides
assert.Len(t, myFinalHostmapHosts, 1)
assert.Len(t, theirFinalHostmapHosts, 1)
assert.Len(t, myFinalHostmapIndexes, 1)
assert.Len(t, theirFinalHostmapIndexes, 1)
r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
//TODO: assert hostmaps }
func TestUncleanShutdownRaceLoser(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
r.Log("Nuke my hostmap")
myHostmap := myControl.GetHostmap()
myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
p = r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(theirControl.GetHostmap().Indexes)
for {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
if len(theirControl.GetHostmap().Indexes) < start {
break
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
}
func TestUncleanShutdownRaceWinner(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
r.Log("Nuke my hostmap")
theirHostmap := theirControl.GetHostmap()
theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(myControl.GetHostmap().Indexes)
for {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
if len(myControl.GetHostmap().Indexes) < start {
break
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
} }
func TestRelays(t *testing.T) { func TestRelays(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
relayControl, relayVpnIp, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIp, relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
myControl.InjectRelays(theirVpnIp, []net.IP{relayVpnIp}) myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
relayControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@ -234,12 +358,84 @@ func TestRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIp, theirVpnIp, 80, 80) r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
} }
func TestStage1RaceRelays(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
r.Log("Trigger a handshake to start on both me and relay")
myControl.InjectTunUDPPacket(relayVpnIpNet.IP, 80, 80, []byte("Hi from me"))
relayControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from relay"))
r.Log("Get both stage 1 handshake packets")
//TODO: this is where it breaks, we need to get the hs packets for the relay not for the destination
myHsForThem := myControl.GetFromUDP(true)
relayHsForMe := relayControl.GetFromUDP(true)
r.Log("Now inject both stage 1 handshake packets")
r.InjectUDPPacket(relayControl, myControl, relayHsForMe)
r.InjectUDPPacket(myControl, relayControl, myHsForThem)
r.Log("Route for me until I send a message packet to relay")
r.RouteForAllUntilAfterMsgTypeTo(relayControl, header.Message, header.MessageNone)
r.Log("My cached packet should be received by relay")
myCachedPacket := relayControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, relayVpnIpNet.IP, 80, 80)
r.Log("Relays cached packet should be received by me")
relayCachedPacket := r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from relay"), relayCachedPacket, relayVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
r.Log("Do a bidirectional tunnel test; me and relay")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
r.Log("Create a tunnel between relay and them")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
r.RenderHostmaps("Starting hostmaps", myControl, relayControl, theirControl)
r.Log("Trigger a handshake to start from me to them via the relay")
//TODO: if we initiate a handshake from me and then assert the tunnel it will cause a relay control race that can blow up
// this is a problem that exists on master today
//myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
//
////TODO: assert hostmaps
}
//TODO: add a test with many lies //TODO: add a test with many lies

View File

@ -30,7 +30,7 @@ import (
type m map[string]interface{} type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, net.IP, *net.UDPAddr) { func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr) {
l := NewTestLogger() l := NewTestLogger()
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
@ -101,7 +101,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
panic(err) panic(err)
} }
return control, vpnIpNet.IP, &udpAddr return control, vpnIpNet, &udpAddr
} }
// newTestCaCert will generate a CA cert // newTestCaCert will generate a CA cert
@ -231,12 +231,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
bPacket := r.RouteUntilTxTun(controlB, controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them // And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
aPacket := r.RouteUntilTxTun(controlA, controlB) aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
} }

View File

@ -5,9 +5,11 @@ package router
import ( import (
"fmt" "fmt"
"sort"
"strings" "strings"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/iputil"
) )
type edge struct { type edge struct {
@ -64,7 +66,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
// Draw the vpn to index nodes // Draw the vpn to index nodes
r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName) r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName)
for vpnIp, hi := range hm.Hosts { for _, vpnIp := range sortedHosts(hm.Hosts) {
hi := hm.Hosts[vpnIp]
r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp) r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp)
lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex())) lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex()))
@ -91,7 +94,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
// Draw the local index to relay or remote index nodes // Draw the local index to relay or remote index nodes
r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName) r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName)
for idx, hi := range hm.Indexes { for _, idx := range sortedIndexes(hm.Indexes) {
hi := hm.Indexes[idx]
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
@ -107,3 +111,29 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
r += "\tend\n" r += "\tend\n"
return r, globalLines return r, globalLines
} }
func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
keys := make([]iputil.VpnIp, 0, len(hosts))
for key := range hosts {
keys = append(keys, key)
}
sort.SliceStable(keys, func(i, j int) bool {
return keys[i] > keys[j]
})
return keys
}
func sortedIndexes(indexes map[uint32]*nebula.HostInfo) []uint32 {
keys := make([]uint32, 0, len(indexes))
for key := range indexes {
keys = append(keys, key)
}
sort.SliceStable(keys, func(i, j int) bool {
return keys[i] > keys[j]
})
return keys
}

View File

@ -10,6 +10,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -22,6 +23,7 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"golang.org/x/exp/maps"
) )
type R struct { type R struct {
@ -150,6 +152,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-clockSource.C: case <-clockSource.C:
r.renderHostmaps("clock tick")
r.renderFlow() r.renderFlow()
} }
} }
@ -220,11 +223,16 @@ func (r *R) renderFlow() {
) )
} }
if len(participantsVals) > 2 {
// Get the first and last participantVals for notes
participantsVals = []string{participantsVals[0], participantsVals[len(participantsVals)-1]}
}
// Print packets // Print packets
h := &header.H{} h := &header.H{}
for _, e := range r.flow { for _, e := range r.flow {
if e.packet == nil { if e.packet == nil {
fmt.Fprintf(f, " note over %s: %s\n", strings.Join(participantsVals, ", "), e.note) //fmt.Fprintf(f, " note over %s: %s\n", strings.Join(participantsVals, ", "), e.note)
continue continue
} }
@ -294,6 +302,28 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
}) })
} }
func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool {
return c[i].GetVpnIp() > c[j].GetVpnIp()
})
s := renderHostmaps(c...)
if len(r.additionalGraphs) > 0 {
lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1]
if lastGraph.content == s {
// Ignore this rendering if it matches the last rendering added
// This is useful if you want to track rendering changes
return
}
}
r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{
title: title,
content: s,
})
}
// InjectFlow can be used to record packet flow if the test is handling the routing on its own. // InjectFlow can be used to record packet flow if the test is handling the routing on its own.
// The packet is assumed to have been received // The packet is assumed to have been received
func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) { func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) {
@ -332,6 +362,8 @@ func (r *R) unlockedInjectFlow(from, to *nebula.Control, p *udp.Packet, tun bool
return nil return nil
} }
r.renderHostmaps(fmt.Sprintf("Packet %v", len(r.flow)))
if len(r.ignoreFlows) > 0 { if len(r.ignoreFlows) > 0 {
var h header.H var h header.H
err := h.Parse(p.Data) err := h.Parse(p.Data)

1
go.mod
View File

@ -21,6 +21,7 @@ require (
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.3.0 golang.org/x/crypto v0.3.0
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2
golang.org/x/net v0.2.0 golang.org/x/net v0.2.0
golang.org/x/sys v0.2.0 golang.org/x/sys v0.2.0
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224

2
go.sum
View File

@ -266,6 +266,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI=
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=

View File

@ -207,9 +207,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert) hostinfo.CreateRemoteCIDR(remoteCert)
// Only overwrite existing record if we should win the handshake race existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
overwrite := vpnIp > f.myVpnIp
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
if err != nil { if err != nil {
switch err { switch err {
case ErrAlreadySeen: case ErrAlreadySeen:
@ -280,16 +278,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
Error("Failed to add HostInfo due to localIndex collision") Error("Failed to add HostInfo due to localIndex collision")
return return
case ErrExistingHandshake:
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Prevented a pending handshake race")
return
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
@ -344,6 +332,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
Info("Handshake message sent") Info("Handshake message sent")
} }
if existing != nil {
// Make sure we are tracking the old primary if there was one, it needs to go away eventually
f.connectionManager.Out(existing.localIndexId)
}
f.connectionManager.Out(hostinfo.localIndexId)
hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
return return
@ -501,8 +495,12 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
hostinfo.CreateRemoteCIDR(remoteCert) hostinfo.CreateRemoteCIDR(remoteCert)
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
//TODO: Complete here does not do a race avoidance, it will just take the new tunnel. Is this ok? existing := f.handshakeManager.Complete(hostinfo, f)
f.handshakeManager.Complete(hostinfo, f) if existing != nil {
// Make sure we are tracking the old primary if there was one, it needs to go away eventually
f.connectionManager.Out(existing.localIndexId)
}
hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
f.metricHandshakes.Update(duration) f.metricHandshakes.Update(duration)

View File

@ -53,6 +53,10 @@ type HandshakeManager struct {
metricTimedOut metrics.Counter metricTimedOut metrics.Counter
l *logrus.Logger l *logrus.Logger
// vpnIps is another map similar to the pending hostmap but tracks entries in the wheel instead
// this is to avoid situations where the same vpn ip enters the wheel and causes rapid fire handshaking
vpnIps map[iputil.VpnIp]struct{}
// can be used to trigger outbound handshake for the given vpnIp // can be used to trigger outbound handshake for the given vpnIp
trigger chan iputil.VpnIp trigger chan iputil.VpnIp
} }
@ -66,6 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
config: config, config: config,
trigger: make(chan iputil.VpnIp, config.triggerBuffer), trigger: make(chan iputil.VpnIp, config.triggerBuffer),
OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
vpnIps: map[iputil.VpnIp]struct{}{},
messageMetrics: config.messageMetrics, messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@ -103,6 +108,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
delete(c.vpnIps, vpnIp)
return return
} }
hostinfo.Lock() hostinfo.Lock()
@ -160,7 +166,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
c.lightHouse.QueryServer(vpnIp, f) c.lightHouse.QueryServer(vpnIp, f)
} }
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr var sentTo []*udp.Addr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
@ -260,7 +266,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered { if !lighthouseTriggered {
//TODO: feel like we dupe handshake real fast in a tight loop, why?
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
} }
} }
@ -269,7 +274,10 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
if created { if created {
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) if _, ok := c.vpnIps[vpnIp]; !ok {
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
}
c.vpnIps[vpnIp] = struct{}{}
c.metricInitiated.Inc(1) c.metricInitiated.Inc(1)
} }
@ -280,7 +288,6 @@ var (
ErrExistingHostInfo = errors.New("existing hostinfo") ErrExistingHostInfo = errors.New("existing hostinfo")
ErrAlreadySeen = errors.New("already seen") ErrAlreadySeen = errors.New("already seen")
ErrLocalIndexCollision = errors.New("local index collision") ErrLocalIndexCollision = errors.New("local index collision")
ErrExistingHandshake = errors.New("existing handshake")
) )
// CheckAndComplete checks for any conflicts in the main and pending hostmap // CheckAndComplete checks for any conflicts in the main and pending hostmap
@ -294,7 +301,7 @@ var (
// //
// ErrLocalIndexCollision if we already have an entry in the main or pending // ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId. // hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) { func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
c.pendingHostMap.Lock() c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock() defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock() c.mainHostMap.Lock()
@ -303,9 +310,14 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// Check if we already have a tunnel with this vpn ip // Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
if found && existingHostInfo != nil { if found && existingHostInfo != nil {
// Is it just a delayed handshake packet? testHostInfo := existingHostInfo
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { for testHostInfo != nil {
return existingHostInfo, ErrAlreadySeen // Is it just a delayed handshake packet?
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
return existingHostInfo, ErrAlreadySeen
}
testHostInfo = testHostInfo.next
} }
// Is this a newer handshake? // Is this a newer handshake?
@ -337,56 +349,19 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
// Check if we are also handshaking with this vpn ip c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
if found && pendingHostInfo != nil {
if !overwrite {
// We won, let our pending handshake win
return pendingHostInfo, ErrExistingHandshake
}
// We lost, take this handshake and move any cached packets over so they get sent
pendingHostInfo.ConnectionState.queueLock.Lock()
hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
pendingHostInfo.ConnectionState.queueLock.Unlock()
pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
}
if existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() {
delete(c.mainHostMap.Relays, relayIdx)
}
}
c.mainHostMap.addHostInfo(hostinfo, f)
return existingHostInfo, nil return existingHostInfo, nil
} }
// Complete is a simpler version of CheckAndComplete when we already know we // Complete is a simpler version of CheckAndComplete when we already know we
// won't have a localIndexId collision because we already have an entry in the // won't have a localIndexId collision because we already have an entry in the
// pendingHostMap // pendingHostMap. An existing hostinfo is returned if there was one.
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo {
c.pendingHostMap.Lock() c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock() defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock() c.mainHostMap.Lock()
defer c.mainHostMap.Unlock() defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
if found && existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() {
delete(c.mainHostMap.Relays, relayIdx)
}
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil { if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
@ -396,8 +371,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
c.mainHostMap.addHostInfo(hostinfo, f) existingHostInfo := c.mainHostMap.Hosts[hostinfo.vpnIp]
c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
c.pendingHostMap.unlockedDeleteHostInfo(hostinfo) c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
return existingHostInfo
} }
// AddIndexHostInfo generates a unique localIndexId for this HostInfo // AddIndexHostInfo generates a unique localIndexId for this HostInfo

View File

@ -23,6 +23,10 @@ const PromoteEvery = 1000
const ReQueryEvery = 5000 const ReQueryEvery = 5000
const MaxRemotes = 10 const MaxRemotes = 10
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice
const MaxHostInfosPerVpnIp = 5
// How long we should prevent roaming back to the previous IP. // How long we should prevent roaming back to the previous IP.
// This helps prevent flapping due to packets already in flight // This helps prevent flapping due to packets already in flight
const RoamingSuppressSeconds = 2 const RoamingSuppressSeconds = 2
@ -180,6 +184,10 @@ type HostInfo struct {
lastRoam time.Time lastRoam time.Time
lastRoamRemote *udp.Addr lastRoamRemote *udp.Addr
// Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock.
next, prev *HostInfo
} }
type ViaSender struct { type ViaSender struct {
@ -395,9 +403,12 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
} }
} }
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
// Delete the host itself, ensuring it's not modified anymore // Delete the host itself, ensuring it's not modified anymore
hm.Lock() hm.Lock()
// If we have a previous or next hostinfo then we are not the last one for this vpn ip
final := (hostinfo.next == nil && hostinfo.prev == nil)
hm.unlockedDeleteHostInfo(hostinfo) hm.unlockedDeleteHostInfo(hostinfo)
hm.Unlock() hm.Unlock()
@ -421,6 +432,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
for _, localIdx := range teardownRelayIdx { for _, localIdx := range teardownRelayIdx {
hm.RemoveRelay(localIdx) hm.RemoveRelay(localIdx)
} }
return final
} }
func (hm *HostMap) DeleteRelayIdx(localIdx uint32) { func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
@ -429,29 +442,81 @@ func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
delete(hm.RemoteIndexes, localIdx) delete(hm.RemoteIndexes, localIdx)
} }
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
// Check if this same hostId is in the hostmap with a different instance. hm.Lock()
// This could happen if we have an entry in the pending hostmap with different defer hm.Unlock()
// index values than the one in the main hostmap. hm.unlockedMakePrimary(hostinfo)
hostinfo2, ok := hm.Hosts[hostinfo.vpnIp] }
if ok && hostinfo2 != hostinfo {
delete(hm.Hosts, hostinfo2.vpnIp) func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
delete(hm.Indexes, hostinfo2.localIndexId) oldHostinfo := hm.Hosts[hostinfo.vpnIp]
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId) if oldHostinfo == hostinfo {
return
} }
delete(hm.Hosts, hostinfo.vpnIp) if hostinfo.prev != nil {
if len(hm.Hosts) == 0 { hostinfo.prev.next = hostinfo.next
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
} }
if hostinfo.next != nil {
hostinfo.next.prev = hostinfo.prev
}
hm.Hosts[hostinfo.vpnIp] = hostinfo
if oldHostinfo == nil {
return
}
hostinfo.next = oldHostinfo
oldHostinfo.prev = hostinfo
hostinfo.prev = nil
}
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
primary, ok := hm.Hosts[hostinfo.vpnIp]
if ok && primary == hostinfo {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp)
if len(hm.Hosts) == 0 {
hm.Hosts = map[iputil.VpnIp]*HostInfo{}
}
if hostinfo.next != nil {
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
hm.Hosts[hostinfo.vpnIp] = hostinfo.next
// It is primary, there is no previous hostinfo now
hostinfo.next.prev = nil
}
} else {
// Relink if we were in the middle of multiple hostinfos for this vpn ip
if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next
}
if hostinfo.next != nil {
hostinfo.next.prev = hostinfo.prev
}
}
hostinfo.next = nil
hostinfo.prev = nil
// The remote index uses index ids outside our control so lets make sure we are only removing
// the remote index pointer here if it points to the hostinfo we are deleting
hostinfo2, ok := hm.RemoteIndexes[hostinfo.remoteIndexId]
if ok && hostinfo2 == hostinfo {
delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
if len(hm.RemoteIndexes) == 0 {
hm.RemoteIndexes = map[uint32]*HostInfo{}
}
}
delete(hm.Indexes, hostinfo.localIndexId) delete(hm.Indexes, hostinfo.localIndexId)
if len(hm.Indexes) == 0 { if len(hm.Indexes) == 0 {
hm.Indexes = map[uint32]*HostInfo{} hm.Indexes = map[uint32]*HostInfo{}
} }
delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
if len(hm.RemoteIndexes) == 0 {
hm.RemoteIndexes = map[uint32]*HostInfo{}
}
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
@ -520,15 +585,22 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
return nil, errors.New("unable to find host") return nil, errors.New("unable to find host")
} }
// We already have the hm Lock when this is called, so make sure to not call // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
// any other methods that might try to grab it again // If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns { if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
} }
existing := hm.Hosts[hostinfo.vpnIp]
hm.Hosts[hostinfo.vpnIp] = hostinfo hm.Hosts[hostinfo.vpnIp] = hostinfo
if existing != nil {
hostinfo.next = existing
existing.prev = hostinfo
}
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
@ -537,6 +609,16 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
Debug("Hostmap vpnIp added") Debug("Hostmap vpnIp added")
} }
i := 1
check := hostinfo
for check != nil {
if i > MaxHostInfosPerVpnIp {
hm.unlockedDeleteHostInfo(check)
}
check = check.next
i++
}
} }
// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap // punchList assembles a list of all non nil RemoteList pointer entries in this hostmap

View File

@ -1 +1,207 @@
package nebula package nebula
import (
"net"
"testing"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
l, "test",
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)
f := &Interface{}
h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f)
hm.unlockedAddHostInfo(h2, f)
hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4
prim, _ := hm.QueryVpnIp(1)
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
assert.Nil(t, h4.next)
// Swap h3/middle to primary
hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4
prim, _ = hm.QueryVpnIp(1)
assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
assert.Nil(t, h4.next)
// Swap h4/tail to primary
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim, _ = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h1.localIndexId, h3.next.localIndexId)
assert.Equal(t, h4.localIndexId, h3.prev.localIndexId)
assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
assert.Nil(t, h2.next)
// Swap h4 again should be no-op
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim, _ = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h1.localIndexId, h3.next.localIndexId)
assert.Equal(t, h4.localIndexId, h3.prev.localIndexId)
assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
assert.Nil(t, h2.next)
}
func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
l, "test",
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)
f := &Interface{}
h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f)
hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f)
hm.unlockedAddHostInfo(h2, f)
hm.unlockedAddHostInfo(h1, f)
// h6 should be deleted
assert.Nil(t, h6.next)
assert.Nil(t, h6.prev)
_, err := hm.QueryIndex(h6.localIndexId)
assert.Error(t, err)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim, _ := hm.QueryVpnIp(1)
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
assert.Nil(t, h5.next)
// Delete primary
hm.DeleteHostInfo(h1)
assert.Nil(t, h1.prev)
assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5
prim, _ = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
assert.Nil(t, h5.next)
// Delete in the middle
hm.DeleteHostInfo(h3)
assert.Nil(t, h3.prev)
assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5
prim, _ = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
assert.Nil(t, h5.next)
// Delete the tail
hm.DeleteHostInfo(h5)
assert.Nil(t, h5.prev)
assert.Nil(t, h5.next)
// Make sure we go h2 -> h4
prim, _ = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
assert.Nil(t, h4.next)
// Delete the head
hm.DeleteHostInfo(h2)
assert.Nil(t, h2.prev)
assert.Nil(t, h2.next)
// Make sure we only have h4
prim, _ = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev)
assert.Nil(t, prim.next)
assert.Nil(t, h4.next)
// Delete the only item
hm.DeleteHostInfo(h4)
assert.Nil(t, h4.prev)
assert.Nil(t, h4.next)
// Make sure we have nil
prim, _ = hm.QueryVpnIp(1)
assert.Nil(t, prim)
}

View File

@ -245,9 +245,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
f.connectionManager.ClearLocalIndex(hostInfo.localIndexId) f.connectionManager.ClearLocalIndex(hostInfo.localIndexId)
f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId) f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId)
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) final := f.hostMap.DeleteHostInfo(hostInfo)
if final {
f.hostMap.DeleteHostInfo(hostInfo) // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
}
} }
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote

View File

@ -51,7 +51,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
// packets should exit the udp side, capture them with udpConn.Get // packets should exit the udp side, capture them with udpConn.Get
func (t *TestTun) Send(packet []byte) { func (t *TestTun) Send(packet []byte) {
if t.l.Level >= logrus.InfoLevel { if t.l.Level >= logrus.InfoLevel {
t.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet") t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
} }
t.rxPackets <- packet t.rxPackets <- packet
} }

View File

@ -61,6 +61,11 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput
_, inRelays := hm.Relays[index] _, inRelays := hm.Relays[index]
if !inRelays { if !inRelays {
// Avoid standing up a relay that can't be used since only the primary hostinfo
// will be pointed to by the relay logic
//TODO: if there was an existing primary and it had relay state, should we merge?
hm.unlockedMakePrimary(relayHostInfo)
hm.Relays[index] = relayHostInfo hm.Relays[index] = relayHostInfo
newRelay := Relay{ newRelay := Relay{
Type: relayType, Type: relayType,

15
ssh.go
View File

@ -22,8 +22,9 @@ import (
) )
type sshListHostMapFlags struct { type sshListHostMapFlags struct {
Json bool Json bool
Pretty bool Pretty bool
ByIndex bool
} }
type sshPrintCertFlags struct { type sshPrintCertFlags struct {
@ -174,6 +175,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
s := sshListHostMapFlags{} s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
@ -189,6 +191,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
s := sshListHostMapFlags{} s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
@ -368,7 +371,13 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
return nil return nil
} }
hm := listHostMap(hostMap) var hm []ControlHostInfo
if fs.ByIndex {
hm = listHostMapIndexes(hostMap)
} else {
hm = listHostMapHosts(hostMap)
}
sort.Slice(hm, func(i, j int) bool { sort.Slice(hm, func(i, j int) bool {
return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
}) })

View File

@ -66,7 +66,7 @@ func (u *Conn) Send(packet *Packet) {
u.l.WithField("header", h). u.l.WithField("header", h).
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
WithField("dataLen", len(packet.Data)). WithField("dataLen", len(packet.Data)).
Info("UDP receiving injected packet") Debug("UDP receiving injected packet")
} }
u.RxPackets <- packet u.RxPackets <- packet
} }