diff --git a/connection_manager.go b/connection_manager.go index 8135421..d1e78ca 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -181,6 +181,14 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) continue } + // Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to + // decide if this should be the main hostinfo if we are seeing traffic on it + primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp) + mainHostInfo := true + if primary != nil && primary != hostinfo { + mainHostInfo = false + } + // If we saw an incoming packets from this ip and peer's certificate is not // expired, just ignore. if traf { @@ -191,6 +199,20 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) } n.ClearLocalIndex(localIndex) n.ClearPendingDeletion(localIndex) + + if !mainHostInfo { + if hostinfo.vpnIp > n.intf.myVpnIp { + // We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make + // This the primary and prime the old primary hostinfo for testing + n.hostMap.MakePrimary(hostinfo) + n.Out(primary.localIndexId) + } else { + // This hostinfo is still being used despite not being the primary hostinfo for this vpn ip + // Keep tracking so that we can tear it down when it goes away + n.Out(hostinfo.localIndexId) + } + } + continue } @@ -198,7 +220,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) WithField("tunnelCheck", m{"state": "testing", "method": "active"}). Debug("Tunnel status") - if hostinfo != nil && hostinfo.ConnectionState != nil { + if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out) diff --git a/connection_manager_test.go b/connection_manager_test.go index 58fdbcd..51e331b 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -80,7 +80,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { certState: cs, H: &noise.HandshakeState{}, } - nc.hostMap.addHostInfo(hostinfo, ifce) + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) @@ -156,7 +156,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { certState: cs, H: &noise.HandshakeState{}, } - nc.hostMap.addHostInfo(hostinfo, ifce) + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) diff --git a/control.go b/control.go index adc2a48..ab3a5cb 100644 --- a/control.go +++ b/control.go @@ -95,12 +95,21 @@ func (c *Control) RebindUDPServer() { c.f.rebindCount++ } -// ListHostmap returns details about the actual or pending (handshaking) hostmap -func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo { +// ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip +func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { if pendingMap { - return listHostMap(c.f.handshakeManager.pendingHostMap) + return listHostMapHosts(c.f.handshakeManager.pendingHostMap) } else { - return listHostMap(c.f.hostMap) + return listHostMapHosts(c.f.hostMap) + } +} + +// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id +func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { + if pendingMap { + return listHostMapIndexes(c.f.handshakeManager.pendingHostMap) + } else { + return listHostMapIndexes(c.f.hostMap) } } @@ -232,7 +241,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { return chi } -func listHostMap(hm *HostMap) []ControlHostInfo { +func listHostMapHosts(hm *HostMap) []ControlHostInfo { hm.RLock() hosts := make([]ControlHostInfo, len(hm.Hosts)) i := 0 @@ -244,3 +253,16 @@ func listHostMap(hm *HostMap) []ControlHostInfo { return hosts } + +func listHostMapIndexes(hm *HostMap) []ControlHostInfo { + hm.RLock() + hosts := make([]ControlHostInfo, len(hm.Indexes)) + i := 0 + for _, v := range hm.Indexes { + hosts[i] = copyHostInfo(v, hm.preferredRanges) + i++ + } + hm.RUnlock() + + return hosts +} diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index bfde43e..d12412e 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -19,10 +19,10 @@ import ( func BenchmarkHotPath(b *testing.B) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Start the servers myControl.Start() @@ -32,7 +32,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -42,18 +42,18 @@ func BenchmarkHotPath(b *testing.B) { func TestGoodHandshake(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -74,16 +74,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -97,15 +97,15 @@ func TestWrongResponderHandshake(t *testing.T) { // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) + myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -117,7 +117,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) @@ -136,18 +136,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) - assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone @@ -157,14 +157,17 @@ func TestWrongResponderHandshake(t *testing.T) { theirControl.Stop() } -func Test_Case1_Stage1Race(t *testing.T) { +func TestStage1Race(t *testing.T) { + // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow + // But will eventually collapse down to a single tunnel + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -175,8 +178,8 @@ func Test_Case1_Stage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -185,44 +188,165 @@ func Test_Case1_Stage1Race(t *testing.T) { r.Log("Now inject both stage 1 handshake packets") r.InjectUDPPacket(theirControl, myControl, theirHsForMe) r.InjectUDPPacket(myControl, theirControl, myHsForThem) - //TODO: they should win, grab their index for me and make sure I use it in the end. - r.Log("They should not have a stage 2 (won the race) but I should send one") - r.InjectUDPPacket(myControl, theirControl, myControl.GetFromUDP(true)) + r.Log("Route until they receive a message packet") + myCachedPacket := r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) - r.Log("Route for me until I send a message packet to them") - r.RouteForAllUntilAfterMsgTypeTo(theirControl, header.Message, header.MessageNone) + r.Log("Their cached packet should be received by me") + theirCachedPacket := r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) - t.Log("My cached packet should be received by them") - myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + r.Log("Do a bidirectional tunnel test") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - t.Log("Route for them until I send a message packet to me") - theirControl.WaitForType(1, 0, myControl) + myHostmapHosts := myControl.ListHostmapHosts(false) + myHostmapIndexes := myControl.ListHostmapIndexes(false) + theirHostmapHosts := theirControl.ListHostmapHosts(false) + theirHostmapIndexes := theirControl.ListHostmapIndexes(false) - t.Log("Their cached packet should be received by me") - theirCachedPacket := myControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80) + // We should have two tunnels on both sides + assert.Len(t, myHostmapHosts, 1) + assert.Len(t, theirHostmapHosts, 1) + assert.Len(t, myHostmapIndexes, 2) + assert.Len(t, theirHostmapIndexes, 2) - t.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) + + r.Log("Spin until connection manager tears down a tunnel") + + for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + t.Log("Connection manager hasn't ticked yet") + time.Sleep(time.Second) + } + + myFinalHostmapHosts := myControl.ListHostmapHosts(false) + myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) + theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) + theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) + + // We should only have a single tunnel now on both sides + assert.Len(t, myFinalHostmapHosts, 1) + assert.Len(t, theirFinalHostmapHosts, 1) + assert.Len(t, myFinalHostmapIndexes, 1) + assert.Len(t, theirFinalHostmapIndexes, 1) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() - //TODO: assert hostmaps +} + +func TestUncleanShutdownRaceLoser(t *testing.T) { + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + r.Log("Trigger a handshake from me to them") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + + r.Log("Nuke my hostmap") + myHostmap := myControl.GetHostmap() + myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + myHostmap.Indexes = map[uint32]*nebula.HostInfo{} + myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} + + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) + p = r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + + r.Log("Assert the tunnel works") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + r.Log("Wait for the dead index to go away") + start := len(theirControl.GetHostmap().Indexes) + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + if len(theirControl.GetHostmap().Indexes) < start { + break + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) +} + +func TestUncleanShutdownRaceWinner(t *testing.T) { + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + r.Log("Trigger a handshake from me to them") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + r.Log("Nuke my hostmap") + theirHostmap := theirControl.GetHostmap() + theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} + theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} + + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) + p = r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + r.RenderHostmaps("Derp hostmaps", myControl, theirControl) + + r.Log("Assert the tunnel works") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + r.Log("Wait for the dead index to go away") + start := len(myControl.GetHostmap().Indexes) + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + if len(myControl.GetHostmap().Indexes) < start { + break + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) } func TestRelays(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIp, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIp, relayUdpAddr) - myControl.InjectRelays(theirVpnIp, []net.IP{relayVpnIp}) - relayControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -234,12 +358,84 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIp, theirVpnIp, 80, 80) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } +func TestStage1RaceRelays(t *testing.T) { + //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + r.Log("Trigger a handshake to start on both me and relay") + myControl.InjectTunUDPPacket(relayVpnIpNet.IP, 80, 80, []byte("Hi from me")) + relayControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from relay")) + + r.Log("Get both stage 1 handshake packets") + //TODO: this is where it breaks, we need to get the hs packets for the relay not for the destination + myHsForThem := myControl.GetFromUDP(true) + relayHsForMe := relayControl.GetFromUDP(true) + + r.Log("Now inject both stage 1 handshake packets") + r.InjectUDPPacket(relayControl, myControl, relayHsForMe) + r.InjectUDPPacket(myControl, relayControl, myHsForThem) + + r.Log("Route for me until I send a message packet to relay") + r.RouteForAllUntilAfterMsgTypeTo(relayControl, header.Message, header.MessageNone) + + r.Log("My cached packet should be received by relay") + myCachedPacket := relayControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, relayVpnIpNet.IP, 80, 80) + + r.Log("Relays cached packet should be received by me") + relayCachedPacket := r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from relay"), relayCachedPacket, relayVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + + r.Log("Do a bidirectional tunnel test; me and relay") + assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + + r.Log("Create a tunnel between relay and them") + assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + + r.RenderHostmaps("Starting hostmaps", myControl, relayControl, theirControl) + + r.Log("Trigger a handshake to start from me to them via the relay") + //TODO: if we initiate a handshake from me and then assert the tunnel it will cause a relay control race that can blow up + // this is a problem that exists on master today + //myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() + relayControl.Stop() + // + ////TODO: assert hostmaps +} + //TODO: add a test with many lies diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index a378bea..3a2d7b5 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -30,7 +30,7 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, net.IP, *net.UDPAddr) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr) { l := NewTestLogger() vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} @@ -101,7 +101,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, vpnIpNet.IP, &udpAddr + return control, vpnIpNet, &udpAddr } // newTestCaCert will generate a CA cert @@ -231,12 +231,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) - bPacket := r.RouteUntilTxTun(controlB, controlA) + bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) - aPacket := r.RouteUntilTxTun(controlA, controlB) + aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 948281a..10627fc 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -5,9 +5,11 @@ package router import ( "fmt" + "sort" "strings" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/iputil" ) type edge struct { @@ -64,7 +66,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { // Draw the vpn to index nodes r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName) - for vpnIp, hi := range hm.Hosts { + for _, vpnIp := range sortedHosts(hm.Hosts) { + hi := hm.Hosts[vpnIp] r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp) lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex())) @@ -91,7 +94,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { // Draw the local index to relay or remote index nodes r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName) - for idx, hi := range hm.Indexes { + for _, idx := range sortedIndexes(hm.Indexes) { + hi := hm.Indexes[idx] r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) @@ -107,3 +111,29 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { r += "\tend\n" return r, globalLines } + +func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { + keys := make([]iputil.VpnIp, 0, len(hosts)) + for key := range hosts { + keys = append(keys, key) + } + + sort.SliceStable(keys, func(i, j int) bool { + return keys[i] > keys[j] + }) + + return keys +} + +func sortedIndexes(indexes map[uint32]*nebula.HostInfo) []uint32 { + keys := make([]uint32, 0, len(indexes)) + for key := range indexes { + keys = append(keys, key) + } + + sort.SliceStable(keys, func(i, j int) bool { + return keys[i] > keys[j] + }) + + return keys +} diff --git a/e2e/router/router.go b/e2e/router/router.go index aa56db8..98bb31d 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "reflect" + "sort" "strconv" "strings" "sync" @@ -22,6 +23,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" + "golang.org/x/exp/maps" ) type R struct { @@ -150,6 +152,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { case <-ctx.Done(): return case <-clockSource.C: + r.renderHostmaps("clock tick") r.renderFlow() } } @@ -220,11 +223,16 @@ func (r *R) renderFlow() { ) } + if len(participantsVals) > 2 { + // Get the first and last participantVals for notes + participantsVals = []string{participantsVals[0], participantsVals[len(participantsVals)-1]} + } + // Print packets h := &header.H{} for _, e := range r.flow { if e.packet == nil { - fmt.Fprintf(f, " note over %s: %s\n", strings.Join(participantsVals, ", "), e.note) + //fmt.Fprintf(f, " note over %s: %s\n", strings.Join(participantsVals, ", "), e.note) continue } @@ -294,6 +302,28 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { }) } +func (r *R) renderHostmaps(title string) { + c := maps.Values(r.controls) + sort.SliceStable(c, func(i, j int) bool { + return c[i].GetVpnIp() > c[j].GetVpnIp() + }) + + s := renderHostmaps(c...) + if len(r.additionalGraphs) > 0 { + lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1] + if lastGraph.content == s { + // Ignore this rendering if it matches the last rendering added + // This is useful if you want to track rendering changes + return + } + } + + r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{ + title: title, + content: s, + }) +} + // InjectFlow can be used to record packet flow if the test is handling the routing on its own. // The packet is assumed to have been received func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) { @@ -332,6 +362,8 @@ func (r *R) unlockedInjectFlow(from, to *nebula.Control, p *udp.Packet, tun bool return nil } + r.renderHostmaps(fmt.Sprintf("Packet %v", len(r.flow))) + if len(r.ignoreFlows) > 0 { var h header.H err := h.Parse(p.Data) diff --git a/go.mod b/go.mod index 8e8a354..d05ab70 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/vishvananda/netlink v1.1.0 golang.org/x/crypto v0.3.0 + golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 golang.org/x/net v0.2.0 golang.org/x/sys v0.2.0 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 diff --git a/go.sum b/go.sum index 3c5eaa7..cb2db8e 100644 --- a/go.sum +++ b/go.sum @@ -266,6 +266,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI= +golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/handshake_ix.go b/handshake_ix.go index 11a16a6..bb511cc 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -207,9 +207,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b hostinfo.SetRemote(addr) hostinfo.CreateRemoteCIDR(remoteCert) - // Only overwrite existing record if we should win the handshake race - overwrite := vpnIp > f.myVpnIp - existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f) + existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { switch err { case ErrAlreadySeen: @@ -280,16 +278,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). Error("Failed to add HostInfo due to localIndex collision") return - case ErrExistingHandshake: - // We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Prevented a pending handshake race") - return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here @@ -344,6 +332,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b Info("Handshake message sent") } + if existing != nil { + // Make sure we are tracking the old primary if there was one, it needs to go away eventually + f.connectionManager.Out(existing.localIndexId) + } + + f.connectionManager.Out(hostinfo.localIndexId) hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) return @@ -501,8 +495,12 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * hostinfo.CreateRemoteCIDR(remoteCert) // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp - //TODO: Complete here does not do a race avoidance, it will just take the new tunnel. Is this ok? - f.handshakeManager.Complete(hostinfo, f) + existing := f.handshakeManager.Complete(hostinfo, f) + if existing != nil { + // Make sure we are tracking the old primary if there was one, it needs to go away eventually + f.connectionManager.Out(existing.localIndexId) + } + hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) f.metricHandshakes.Update(duration) diff --git a/handshake_manager.go b/handshake_manager.go index 4325841..06805b6 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -53,6 +53,10 @@ type HandshakeManager struct { metricTimedOut metrics.Counter l *logrus.Logger + // vpnIps is another map similar to the pending hostmap but tracks entries in the wheel instead + // this is to avoid situations where the same vpn ip enters the wheel and causes rapid fire handshaking + vpnIps map[iputil.VpnIp]struct{} + // can be used to trigger outbound handshake for the given vpnIp trigger chan iputil.VpnIp } @@ -66,6 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ config: config, trigger: make(chan iputil.VpnIp, config.triggerBuffer), OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + vpnIps: map[iputil.VpnIp]struct{}{}, messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -103,6 +108,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { + delete(c.vpnIps, vpnIp) return } hostinfo.Lock() @@ -160,7 +166,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l c.lightHouse.QueryServer(vpnIp, f) } - // Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply + // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []*udp.Addr hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) @@ -260,7 +266,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { - //TODO: feel like we dupe handshake real fast in a tight loop, why? c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) } } @@ -269,7 +274,10 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) if created { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) + if _, ok := c.vpnIps[vpnIp]; !ok { + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) + } + c.vpnIps[vpnIp] = struct{}{} c.metricInitiated.Inc(1) } @@ -280,7 +288,6 @@ var ( ErrExistingHostInfo = errors.New("existing hostinfo") ErrAlreadySeen = errors.New("already seen") ErrLocalIndexCollision = errors.New("local index collision") - ErrExistingHandshake = errors.New("existing handshake") ) // CheckAndComplete checks for any conflicts in the main and pending hostmap @@ -294,7 +301,7 @@ var ( // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. -func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) { +func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { c.pendingHostMap.Lock() defer c.pendingHostMap.Unlock() c.mainHostMap.Lock() @@ -303,9 +310,14 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // Check if we already have a tunnel with this vpn ip existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] if found && existingHostInfo != nil { - // Is it just a delayed handshake packet? - if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { - return existingHostInfo, ErrAlreadySeen + testHostInfo := existingHostInfo + for testHostInfo != nil { + // Is it just a delayed handshake packet? + if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { + return existingHostInfo, ErrAlreadySeen + } + + testHostInfo = testHostInfo.next } // Is this a newer handshake? @@ -337,56 +349,19 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket Info("New host shadows existing host remoteIndex") } - // Check if we are also handshaking with this vpn ip - pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp] - if found && pendingHostInfo != nil { - if !overwrite { - // We won, let our pending handshake win - return pendingHostInfo, ErrExistingHandshake - } - - // We lost, take this handshake and move any cached packets over so they get sent - pendingHostInfo.ConnectionState.queueLock.Lock() - hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...) - c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo) - pendingHostInfo.ConnectionState.queueLock.Unlock() - pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel") - } - - if existingHostInfo != nil { - // We are going to overwrite this entry, so remove the old references - delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp) - delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) - delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) - for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() { - delete(c.mainHostMap.Relays, relayIdx) - } - } - - c.mainHostMap.addHostInfo(hostinfo, f) + c.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } // Complete is a simpler version of CheckAndComplete when we already know we // won't have a localIndexId collision because we already have an entry in the -// pendingHostMap -func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { +// pendingHostMap. An existing hostinfo is returned if there was one. +func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo { c.pendingHostMap.Lock() defer c.pendingHostMap.Unlock() c.mainHostMap.Lock() defer c.mainHostMap.Unlock() - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] - if found && existingHostInfo != nil { - // We are going to overwrite this entry, so remove the old references - delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp) - delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId) - delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId) - for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() { - delete(c.mainHostMap.Relays, relayIdx) - } - } - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control @@ -396,8 +371,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { Info("New host shadows existing host remoteIndex") } - c.mainHostMap.addHostInfo(hostinfo, f) + existingHostInfo := c.mainHostMap.Hosts[hostinfo.vpnIp] + c.mainHostMap.unlockedAddHostInfo(hostinfo, f) c.pendingHostMap.unlockedDeleteHostInfo(hostinfo) + return existingHostInfo } // AddIndexHostInfo generates a unique localIndexId for this HostInfo diff --git a/hostmap.go b/hostmap.go index 372333e..231beb1 100644 --- a/hostmap.go +++ b/hostmap.go @@ -23,6 +23,10 @@ const PromoteEvery = 1000 const ReQueryEvery = 5000 const MaxRemotes = 10 +// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip +// 5 allows for an initial handshake and each host pair re-handshaking twice +const MaxHostInfosPerVpnIp = 5 + // How long we should prevent roaming back to the previous IP. // This helps prevent flapping due to packets already in flight const RoamingSuppressSeconds = 2 @@ -180,6 +184,10 @@ type HostInfo struct { lastRoam time.Time lastRoamRemote *udp.Addr + + // Used to track other hostinfos for this vpn ip since only 1 can be primary + // Synchronised via hostmap lock and not the hostinfo lock. + next, prev *HostInfo } type ViaSender struct { @@ -395,9 +403,12 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) { } } -func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { +// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip +func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore hm.Lock() + // If we have a previous or next hostinfo then we are not the last one for this vpn ip + final := (hostinfo.next == nil && hostinfo.prev == nil) hm.unlockedDeleteHostInfo(hostinfo) hm.Unlock() @@ -421,6 +432,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) { for _, localIdx := range teardownRelayIdx { hm.RemoveRelay(localIdx) } + + return final } func (hm *HostMap) DeleteRelayIdx(localIdx uint32) { @@ -429,29 +442,81 @@ func (hm *HostMap) DeleteRelayIdx(localIdx uint32) { delete(hm.RemoteIndexes, localIdx) } -func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { - // Check if this same hostId is in the hostmap with a different instance. - // This could happen if we have an entry in the pending hostmap with different - // index values than the one in the main hostmap. - hostinfo2, ok := hm.Hosts[hostinfo.vpnIp] - if ok && hostinfo2 != hostinfo { - delete(hm.Hosts, hostinfo2.vpnIp) - delete(hm.Indexes, hostinfo2.localIndexId) - delete(hm.RemoteIndexes, hostinfo2.remoteIndexId) +func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { + hm.Lock() + defer hm.Unlock() + hm.unlockedMakePrimary(hostinfo) +} + +func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { + oldHostinfo := hm.Hosts[hostinfo.vpnIp] + if oldHostinfo == hostinfo { + return } - delete(hm.Hosts, hostinfo.vpnIp) - if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} + if hostinfo.prev != nil { + hostinfo.prev.next = hostinfo.next } + + if hostinfo.next != nil { + hostinfo.next.prev = hostinfo.prev + } + + hm.Hosts[hostinfo.vpnIp] = hostinfo + + if oldHostinfo == nil { + return + } + + hostinfo.next = oldHostinfo + oldHostinfo.prev = hostinfo + hostinfo.prev = nil +} + +func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { + primary, ok := hm.Hosts[hostinfo.vpnIp] + if ok && primary == hostinfo { + // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it + delete(hm.Hosts, hostinfo.vpnIp) + if len(hm.Hosts) == 0 { + hm.Hosts = map[iputil.VpnIp]*HostInfo{} + } + + if hostinfo.next != nil { + // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary + hm.Hosts[hostinfo.vpnIp] = hostinfo.next + // It is primary, there is no previous hostinfo now + hostinfo.next.prev = nil + } + + } else { + // Relink if we were in the middle of multiple hostinfos for this vpn ip + if hostinfo.prev != nil { + hostinfo.prev.next = hostinfo.next + } + + if hostinfo.next != nil { + hostinfo.next.prev = hostinfo.prev + } + } + + hostinfo.next = nil + hostinfo.prev = nil + + // The remote index uses index ids outside our control so lets make sure we are only removing + // the remote index pointer here if it points to the hostinfo we are deleting + hostinfo2, ok := hm.RemoteIndexes[hostinfo.remoteIndexId] + if ok && hostinfo2 == hostinfo { + delete(hm.RemoteIndexes, hostinfo.remoteIndexId) + if len(hm.RemoteIndexes) == 0 { + hm.RemoteIndexes = map[uint32]*HostInfo{} + } + } + delete(hm.Indexes, hostinfo.localIndexId) if len(hm.Indexes) == 0 { hm.Indexes = map[uint32]*HostInfo{} } - delete(hm.RemoteIndexes, hostinfo.remoteIndexId) - if len(hm.RemoteIndexes) == 0 { - hm.RemoteIndexes = map[uint32]*HostInfo{} - } if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), @@ -520,15 +585,22 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host return nil, errors.New("unable to find host") } -// We already have the hm Lock when this is called, so make sure to not call -// any other methods that might try to grab it again -func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { +// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. +// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary +func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) } + existing := hm.Hosts[hostinfo.vpnIp] hm.Hosts[hostinfo.vpnIp] = hostinfo + + if existing != nil { + hostinfo.next = existing + existing.prev = hostinfo + } + hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo @@ -537,6 +609,16 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). Debug("Hostmap vpnIp added") } + + i := 1 + check := hostinfo + for check != nil { + if i > MaxHostInfosPerVpnIp { + hm.unlockedDeleteHostInfo(check) + } + check = check.next + i++ + } } // punchList assembles a list of all non nil RemoteList pointer entries in this hostmap diff --git a/hostmap_test.go b/hostmap_test.go index 2808317..e523a21 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -1 +1,207 @@ package nebula + +import ( + "net" + "testing" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +func TestHostMap_MakePrimary(t *testing.T) { + l := test.NewLogger() + hm := NewHostMap( + l, "test", + &net.IPNet{ + IP: net.IP{10, 0, 0, 1}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + []*net.IPNet{}, + ) + + f := &Interface{} + + h1 := &HostInfo{vpnIp: 1, localIndexId: 1} + h2 := &HostInfo{vpnIp: 1, localIndexId: 2} + h3 := &HostInfo{vpnIp: 1, localIndexId: 3} + h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + + hm.unlockedAddHostInfo(h4, f) + hm.unlockedAddHostInfo(h3, f) + hm.unlockedAddHostInfo(h2, f) + hm.unlockedAddHostInfo(h1, f) + + // Make sure we go h1 -> h2 -> h3 -> h4 + prim, _ := hm.QueryVpnIp(1) + assert.Equal(t, h1.localIndexId, prim.localIndexId) + assert.Equal(t, h2.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Equal(t, h3.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h3.next.localIndexId) + assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) + assert.Nil(t, h4.next) + + // Swap h3/middle to primary + hm.MakePrimary(h3) + + // Make sure we go h3 -> h1 -> h2 -> h4 + prim, _ = hm.QueryVpnIp(1) + assert.Equal(t, h3.localIndexId, prim.localIndexId) + assert.Equal(t, h1.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h2.localIndexId, h1.next.localIndexId) + assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h2.next.localIndexId) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) + assert.Nil(t, h4.next) + + // Swap h4/tail to primary + hm.MakePrimary(h4) + + // Make sure we go h4 -> h3 -> h1 -> h2 + prim, _ = hm.QueryVpnIp(1) + assert.Equal(t, h4.localIndexId, prim.localIndexId) + assert.Equal(t, h3.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h1.localIndexId, h3.next.localIndexId) + assert.Equal(t, h4.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h2.localIndexId, h1.next.localIndexId) + assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Nil(t, h2.next) + + // Swap h4 again should be no-op + hm.MakePrimary(h4) + + // Make sure we go h4 -> h3 -> h1 -> h2 + prim, _ = hm.QueryVpnIp(1) + assert.Equal(t, h4.localIndexId, prim.localIndexId) + assert.Equal(t, h3.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h1.localIndexId, h3.next.localIndexId) + assert.Equal(t, h4.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h2.localIndexId, h1.next.localIndexId) + assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Nil(t, h2.next) +} + +func TestHostMap_DeleteHostInfo(t *testing.T) { + l := test.NewLogger() + hm := NewHostMap( + l, "test", + &net.IPNet{ + IP: net.IP{10, 0, 0, 1}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + []*net.IPNet{}, + ) + + f := &Interface{} + + h1 := &HostInfo{vpnIp: 1, localIndexId: 1} + h2 := &HostInfo{vpnIp: 1, localIndexId: 2} + h3 := &HostInfo{vpnIp: 1, localIndexId: 3} + h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + h5 := &HostInfo{vpnIp: 1, localIndexId: 5} + h6 := &HostInfo{vpnIp: 1, localIndexId: 6} + + hm.unlockedAddHostInfo(h6, f) + hm.unlockedAddHostInfo(h5, f) + hm.unlockedAddHostInfo(h4, f) + hm.unlockedAddHostInfo(h3, f) + hm.unlockedAddHostInfo(h2, f) + hm.unlockedAddHostInfo(h1, f) + + // h6 should be deleted + assert.Nil(t, h6.next) + assert.Nil(t, h6.prev) + _, err := hm.QueryIndex(h6.localIndexId) + assert.Error(t, err) + + // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 + prim, _ := hm.QueryVpnIp(1) + assert.Equal(t, h1.localIndexId, prim.localIndexId) + assert.Equal(t, h2.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) + assert.Equal(t, h3.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h3.next.localIndexId) + assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) + assert.Equal(t, h5.localIndexId, h4.next.localIndexId) + assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) + assert.Nil(t, h5.next) + + // Delete primary + hm.DeleteHostInfo(h1) + assert.Nil(t, h1.prev) + assert.Nil(t, h1.next) + + // Make sure we go h2 -> h3 -> h4 -> h5 + prim, _ = hm.QueryVpnIp(1) + assert.Equal(t, h2.localIndexId, prim.localIndexId) + assert.Equal(t, h3.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h3.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) + assert.Equal(t, h4.localIndexId, h3.next.localIndexId) + assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) + assert.Equal(t, h5.localIndexId, h4.next.localIndexId) + assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) + assert.Nil(t, h5.next) + + // Delete in the middle + hm.DeleteHostInfo(h3) + assert.Nil(t, h3.prev) + assert.Nil(t, h3.next) + + // Make sure we go h2 -> h4 -> h5 + prim, _ = hm.QueryVpnIp(1) + assert.Equal(t, h2.localIndexId, prim.localIndexId) + assert.Equal(t, h4.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h4.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) + assert.Equal(t, h5.localIndexId, h4.next.localIndexId) + assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) + assert.Nil(t, h5.next) + + // Delete the tail + hm.DeleteHostInfo(h5) + assert.Nil(t, h5.prev) + assert.Nil(t, h5.next) + + // Make sure we go h2 -> h4 + prim, _ = hm.QueryVpnIp(1) + assert.Equal(t, h2.localIndexId, prim.localIndexId) + assert.Equal(t, h4.localIndexId, prim.next.localIndexId) + assert.Nil(t, prim.prev) + assert.Equal(t, h4.localIndexId, h2.next.localIndexId) + assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) + assert.Nil(t, h4.next) + + // Delete the head + hm.DeleteHostInfo(h2) + assert.Nil(t, h2.prev) + assert.Nil(t, h2.next) + + // Make sure we only have h4 + prim, _ = hm.QueryVpnIp(1) + assert.Equal(t, h4.localIndexId, prim.localIndexId) + assert.Nil(t, prim.prev) + assert.Nil(t, prim.next) + assert.Nil(t, h4.next) + + // Delete the only item + hm.DeleteHostInfo(h4) + assert.Nil(t, h4.prev) + assert.Nil(t, h4.next) + + // Make sure we have nil + prim, _ = hm.QueryVpnIp(1) + assert.Nil(t, prim) +} diff --git a/outside.go b/outside.go index c43a385..605325d 100644 --- a/outside.go +++ b/outside.go @@ -245,9 +245,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately f.connectionManager.ClearLocalIndex(hostInfo.localIndexId) f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId) - f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) - - f.hostMap.DeleteHostInfo(hostInfo) + final := f.hostMap.DeleteHostInfo(hostInfo) + if final { + // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) + } } // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index a4ee20b..442a9b5 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -51,7 +51,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int // packets should exit the udp side, capture them with udpConn.Get func (t *TestTun) Send(packet []byte) { if t.l.Level >= logrus.InfoLevel { - t.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet") + t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") } t.rxPackets <- packet } diff --git a/relay_manager.go b/relay_manager.go index 95807bd..080d144 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -61,6 +61,11 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput _, inRelays := hm.Relays[index] if !inRelays { + // Avoid standing up a relay that can't be used since only the primary hostinfo + // will be pointed to by the relay logic + //TODO: if there was an existing primary and it had relay state, should we merge? + hm.unlockedMakePrimary(relayHostInfo) + hm.Relays[index] = relayHostInfo newRelay := Relay{ Type: relayType, diff --git a/ssh.go b/ssh.go index f8050ff..7b9e28a 100644 --- a/ssh.go +++ b/ssh.go @@ -22,8 +22,9 @@ import ( ) type sshListHostMapFlags struct { - Json bool - Pretty bool + Json bool + Pretty bool + ByIndex bool } type sshPrintCertFlags struct { @@ -174,6 +175,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { @@ -189,6 +191,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { @@ -368,7 +371,13 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error return nil } - hm := listHostMap(hostMap) + var hm []ControlHostInfo + if fs.ByIndex { + hm = listHostMapIndexes(hostMap) + } else { + hm = listHostMapHosts(hostMap) + } + sort.Slice(hm, func(i, j int) bool { return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 }) diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 55213b8..b3e2498 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -66,7 +66,7 @@ func (u *Conn) Send(packet *Packet) { u.l.WithField("header", h). WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). WithField("dataLen", len(packet.Data)). - Info("UDP receiving injected packet") + Debug("UDP receiving injected packet") } u.RxPackets <- packet }