diff --git a/calculated_remote.go b/calculated_remote.go new file mode 100644 index 0000000..910f757 --- /dev/null +++ b/calculated_remote.go @@ -0,0 +1,143 @@ +package nebula + +import ( + "fmt" + "math" + "net" + "strconv" + + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" +) + +// This allows us to "guess" what the remote might be for a host while we wait +// for the lighthouse response. See "lighthouse.calculated_remotes" in the +// example config file. +type calculatedRemote struct { + ipNet net.IPNet + maskIP iputil.VpnIp + mask iputil.VpnIp + port uint32 +} + +func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) { + // Ensure this is an IPv4 mask that we expect + ones, bits := ipNet.Mask.Size() + if ones == 0 || bits != 32 { + return nil, fmt.Errorf("invalid mask: %v", ipNet) + } + if port < 0 || port > math.MaxUint16 { + return nil, fmt.Errorf("invalid port: %d", port) + } + + return &calculatedRemote{ + ipNet: *ipNet, + maskIP: iputil.Ip2VpnIp(ipNet.IP), + mask: iputil.Ip2VpnIp(ipNet.Mask), + port: uint32(port), + }, nil +} + +func (c *calculatedRemote) String() string { + return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) +} + +func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { + // Combine the masked bytes of the "mask" IP with the unmasked bytes + // of the overlay IP + masked := (c.maskIP & c.mask) | (ip & ^c.mask) + + return &Ip4AndPort{Ip: uint32(masked), Port: c.port} +} + +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) { + value := c.Get(k) + if value == nil { + return nil, nil + } + + calculatedRemotes := cidr.NewTree4() + + rawMap, ok := value.(map[any]any) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) + } + for rawKey, rawValue := range rawMap { + rawCIDR, ok := rawKey.(string) + if !ok { + return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) + } + + _, ipNet, err := net.ParseCIDR(rawCIDR) + if err != nil { + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + } + + entry, err := newCalculatedRemotesListFromConfig(rawValue) + if err != nil { + return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) + } + + calculatedRemotes.AddCIDR(ipNet, entry) + } + + return calculatedRemotes, nil +} + +func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { + rawList, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) + } + + var l []*calculatedRemote + for _, e := range rawList { + c, err := newCalculatedRemotesEntryFromConfig(e) + if err != nil { + return nil, fmt.Errorf("calculated_remotes entry: %w", err) + } + l = append(l, c) + } + + return l, nil +} + +func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { + rawMap, ok := raw.(map[any]any) + if !ok { + return nil, fmt.Errorf("invalid type: %T", raw) + } + + rawValue := rawMap["mask"] + if rawValue == nil { + return nil, fmt.Errorf("missing mask: %v", rawMap) + } + rawMask, ok := rawValue.(string) + if !ok { + return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) + } + _, ipNet, err := net.ParseCIDR(rawMask) + if err != nil { + return nil, fmt.Errorf("invalid mask: %s", rawMask) + } + + var port int + rawValue = rawMap["port"] + if rawValue == nil { + return nil, fmt.Errorf("missing port: %v", rawMap) + } + switch v := rawValue.(type) { + case int: + port = v + case string: + port, err = strconv.Atoi(v) + if err != nil { + return nil, fmt.Errorf("invalid port: %s: %w", v, err) + } + default: + return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) + } + + return newCalculatedRemote(ipNet, port) +} diff --git a/calculated_remote_test.go b/calculated_remote_test.go new file mode 100644 index 0000000..2ddebca --- /dev/null +++ b/calculated_remote_test.go @@ -0,0 +1,27 @@ +package nebula + +import ( + "net" + "testing" + + "github.com/slackhq/nebula/iputil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCalculatedRemoteApply(t *testing.T) { + _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + require.NoError(t, err) + + c, err := newCalculatedRemote(ipNet, 4242) + require.NoError(t, err) + + input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + + expected := &Ip4AndPort{ + Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), + Port: 4242, + } + + assert.Equal(t, expected, c.Apply(input)) +} diff --git a/examples/config.yml b/examples/config.yml index 9fe95ce..f7bb95d 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -91,6 +91,19 @@ lighthouse: #- "1.1.1.1:4242" #- "1.2.3.4:0" # port will be replaced with the real listening port + # EXPERIMENTAL: This option may change or disappear in the future. + # This setting allows us to "guess" what the remote might be for a host + # while we wait for the lighthouse response. + #calculated_remotes: + # For any Nebula IPs in 10.0.10.0/24, this will apply the mask and add + # the calculated IP as an initial remote (while we wait for the response + # from the lighthouse). Both CIDRs must have the same mask size. + # For example, Nebula IP 10.0.10.123 will have a calculated remote of + # 192.168.1.123 + #10.0.10.0/24: + #- mask: 192.168.1.0/24 + # port: 4242 + # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: diff --git a/handshake_manager.go b/handshake_manager.go index 06805b6..8166bda 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -142,14 +142,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l return } - // We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific - // optimization for a fast lighthouse reply - //TODO: it would feel better to do this once, anytime, as our delay increases over time - if lighthouseTriggered && hostinfo.HandshakeCounter > 0 { - // If we didn't return here a lighthouse could cause us to aggressively send handshakes - return - } - // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. @@ -158,8 +150,22 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) } - //TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped) - if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 { + remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges) + remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes) + + // We only care about a lighthouse trigger if we have new remotes to send to. + // This is a very specific optimization for a fast lighthouse reply. + if lighthouseTriggered && !remotesHaveChanged { + // If we didn't return here a lighthouse could cause us to aggressively send handshakes + return + } + + hostinfo.HandshakeLastRemotes = remotes + + // TODO: this will generate a load of queries for hosts with only 1 ip + // (such as ones registered to the lighthouse with only a private IP) + // So we only do it one time after attempting 5 handshakes already. + if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 { // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter @@ -182,12 +188,18 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l } }) - // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout - if len(sentTo) > 0 { + // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, + // so only log when the list of remotes has changed + if remotesHaveChanged { hostinfo.logger(c.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message sent") + } else if c.l.IsLevelEnabled(logrus.DebugLevel) { + hostinfo.logger(c.l).WithField("udpAddrs", sentTo). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Debug("Handshake message sent") } if c.config.useRelays && len(hostinfo.remotes.relays) > 0 { diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 413a50a..84b8ef6 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -66,46 +66,6 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.pendingHostMap.Hosts, ip) } -func Test_NewHandshakeManagerTrigger(t *testing.T) { - l := test.NewLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} - mw := &mockEncWriter{} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := newTestLighthouse() - - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) - - now := time.Now() - blah.NextOutboundHandshakeTimerTick(now, mw) - - assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - - hi := blah.AddVpnIp(ip, nil) - hi.HandshakeReady = true - assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet") - - // Trigger the same method the channel will but, this should set our remotes pointer - blah.handleOutbound(ip, mw, true) - assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt") - assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer") - - // Make sure the trigger doesn't double schedule the timer entry - assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) - - uaddr := udp.NewAddrFromString("10.1.1.1:4242") - hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))) - - // We now have remotes but only the first trigger should have pushed things forward - blah.handleOutbound(ip, mw, true) - assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt") - assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) -} - func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { for _, i := range tw.t.wheel { n := i.Head diff --git a/hostmap.go b/hostmap.go index 231beb1..185ecf5 100644 --- a/hostmap.go +++ b/hostmap.go @@ -155,22 +155,23 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { type HostInfo struct { sync.RWMutex - remote *udp.Addr - remotes *RemoteList - promoteCounter atomic.Uint32 - ConnectionState *ConnectionState - handshakeStart time.Time //todo: this an entry in the handshake manager - HandshakeReady bool //todo: being in the manager means you are ready - HandshakeCounter int //todo: another handshake manager entry - HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready - HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry - packetStore []*cachedPacket //todo: this is other handshake manager entry - remoteIndexId uint32 - localIndexId uint32 - vpnIp iputil.VpnIp - recvError int - remoteCidr *cidr.Tree4 - relayState RelayState + remote *udp.Addr + remotes *RemoteList + promoteCounter atomic.Uint32 + ConnectionState *ConnectionState + handshakeStart time.Time //todo: this an entry in the handshake manager + HandshakeReady bool //todo: being in the manager means you are ready + HandshakeCounter int //todo: another handshake manager entry + HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time + HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready + HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry + packetStore []*cachedPacket //todo: this is other handshake manager entry + remoteIndexId uint32 + localIndexId uint32 + vpnIp iputil.VpnIp + recvError int + remoteCidr *cidr.Tree4 + relayState RelayState // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like diff --git a/inside.go b/inside.go index 0734883..ddfaa20 100644 --- a/inside.go +++ b/inside.go @@ -153,7 +153,13 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now - if _, ok := f.lightHouse.GetStaticHostList()[vpnIp]; ok { + _, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp] + if !doTrigger { + // Add any calculated remotes, and trigger early handshake if one found + doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp) + } + + if doTrigger { select { case f.handshakeManager.trigger <- vpnIp: default: diff --git a/lighthouse.go b/lighthouse.go index 60e1f29..a3341b4 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -12,6 +12,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" @@ -72,6 +73,8 @@ type LightHouse struct { // IP's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]iputil.VpnIp] + calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote + metrics *MessageMetrics metricHolepunchTx metrics.Counter l *logrus.Logger @@ -161,6 +164,10 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { return *lh.relaysForMe.Load() } +func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 { + return lh.calculatedRemotes.Load() +} + func (lh *LightHouse) GetUpdateInterval() int64 { return lh.interval.Load() } @@ -237,6 +244,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } } + if initial || c.HasChanged("lighthouse.calculated_remotes") { + cr, err := NewCalculatedRemotesFromConfig(c, "lighthouse.calculated_remotes") + if err != nil { + return util.NewContextualError("Invalid lighthouse.calculated_remotes", nil, err) + } + + lh.calculatedRemotes.Store(cr) + if !initial { + //TODO: a diff will be annoyingly difficult + lh.l.Info("lighthouse.calculated_remotes has changed") + } + } + //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config if initial || c.HasChanged("static_host_map") { staticList := make(map[iputil.VpnIp]struct{}) @@ -488,6 +508,39 @@ func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, stat staticList[vpnIp] = struct{}{} } +// addCalculatedRemotes adds any calculated remotes based on the +// lighthouse.calculated_remotes configuration. It returns true if any +// calculated remotes were added +func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { + tree := lh.getCalculatedRemotes() + if tree == nil { + return false + } + value := tree.MostSpecificContains(vpnIp) + if value == nil { + return false + } + calculatedRemotes := value.([]*calculatedRemote) + + var calculated []*Ip4AndPort + for _, cr := range calculatedRemotes { + c := cr.Apply(vpnIp) + if c != nil { + calculated = append(calculated, c) + } + } + + lh.Lock() + am := lh.unlockedGetRemoteList(vpnIp) + am.Lock() + defer am.Unlock() + lh.Unlock() + + am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) + + return len(calculated) > 0 +} + // unlockedGetRemoteList assumes you have the lh lock func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { am, ok := lh.addrMap[vpnIp] diff --git a/udp/udp_all.go b/udp/udp_all.go index a4a462e..093bf69 100644 --- a/udp/udp_all.go +++ b/udp/udp_all.go @@ -64,6 +64,22 @@ func (ua *Addr) Copy() *Addr { return &nu } +type AddrSlice []*Addr + +func (a AddrSlice) Equal(b AddrSlice) bool { + if len(a) != len(b) { + return false + } + + for i := range a { + if !a[i].Equals(b[i]) { + return false + } + } + + return true +} + func ParseIPAndPort(s string) (net.IP, uint16, error) { rIp, sPort, err := net.SplitHostPort(s) if err != nil {