From 312a01dc093593dc59d01cca2522c9791c07e851 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 14 Mar 2022 12:35:13 -0500 Subject: [PATCH] Lighthouse reload support (#649) Co-authored-by: John Maguire --- config/config.go | 26 ++++ connection_manager_test.go | 6 +- control.go | 3 +- handshake.go | 2 +- handshake_ix.go | 4 +- handshake_manager_test.go | 13 +- inside.go | 2 +- lighthouse.go | 298 +++++++++++++++++++++++++++++-------- lighthouse_test.go | 95 ++++++++---- main.go | 97 ++---------- outside.go | 2 +- punchy.go | 93 +++++++++--- punchy_test.go | 52 +++++-- 13 files changed, 471 insertions(+), 222 deletions(-) diff --git a/config/config.go b/config/config.go index 2328007..c51a78c 100644 --- a/config/config.go +++ b/config/config.go @@ -11,6 +11,7 @@ import ( "sort" "strconv" "strings" + "sync" "syscall" "time" @@ -26,6 +27,7 @@ type C struct { oldSettings map[interface{}]interface{} callbacks []func(*C) l *logrus.Logger + reloadLock sync.Mutex } func NewC(l *logrus.Logger) *C { @@ -133,6 +135,9 @@ func (c *C) CatchHUP(ctx context.Context) { } func (c *C) ReloadConfig() { + c.reloadLock.Lock() + defer c.reloadLock.Unlock() + c.oldSettings = make(map[interface{}]interface{}) for k, v := range c.Settings { c.oldSettings[k] = v @@ -149,6 +154,27 @@ func (c *C) ReloadConfig() { } } +func (c *C) ReloadConfigString(raw string) error { + c.reloadLock.Lock() + defer c.reloadLock.Unlock() + + c.oldSettings = make(map[interface{}]interface{}) + for k, v := range c.Settings { + c.oldSettings[k] = v + } + + err := c.LoadString(raw) + if err != nil { + return err + } + + for _, v := range c.callbacks { + v(c) + } + + return nil +} + // GetString will get the string for k or return the default d if not found or invalid func (c *C) GetString(k, d string) string { r := c.Get(k) diff --git a/connection_manager_test.go b/connection_manager_test.go index 80a0178..bae48e5 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -35,7 +35,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false) + lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, @@ -104,7 +104,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false) + lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, @@ -213,7 +213,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false) + lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, diff --git a/control.go b/control.go index c905d23..27521da 100644 --- a/control.go +++ b/control.go @@ -160,9 +160,10 @@ func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { //TODO: this is probably better as a function in ConnectionManager or HostMap directly c.f.hostMap.Lock() + lighthouses := c.f.lightHouse.GetLighthouses() for _, h := range c.f.hostMap.Hosts { if excludeLighthouses { - if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok { + if _, ok := lighthouses[h.vpnIp]; ok { continue } } diff --git a/handshake.go b/handshake.go index a08fb2e..fa66711 100644 --- a/handshake.go +++ b/handshake.go @@ -7,7 +7,7 @@ import ( func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) { // First remote allow list check before we know the vpnIp - if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) { + if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } diff --git a/handshake_ix.go b/handshake_ix.go index a0defc6..9e6a89b 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -114,7 +114,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) return } - if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) { + if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -321,7 +321,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet hostinfo.Lock() defer hostinfo.Unlock() - if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 0ca651c..df36be0 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -21,8 +21,12 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) + lh := &LightHouse{ + atomicStaticList: make(map[iputil.VpnIp]struct{}), + atomicLighthouses: make(map[iputil.VpnIp]struct{}), + } - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) @@ -74,7 +78,12 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) { preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l} + lh := &LightHouse{ + addrMap: make(map[iputil.VpnIp]*RemoteList), + l: l, + atomicStaticList: make(map[iputil.VpnIp]struct{}), + atomicLighthouses: make(map[iputil.VpnIp]struct{}), + } blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) diff --git a/inside.go b/inside.go index 988bb65..3bbdc5f 100644 --- a/inside.go +++ b/inside.go @@ -110,7 +110,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now - if _, ok := f.lightHouse.staticList[vpnIp]; ok { + if _, ok := f.lightHouse.GetStaticHostList()[vpnIp]; ok { select { case f.handshakeManager.trigger <- vpnIp: default: diff --git a/lighthouse.go b/lighthouse.go index 6c07440..e3a9bc7 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,14 +7,18 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" + "unsafe" "github.com/golang/protobuf/proto" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" + "github.com/slackhq/nebula/util" ) //TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? @@ -28,7 +32,9 @@ type LightHouse struct { amLighthouse bool myVpnIp iputil.VpnIp myVpnZeros iputil.VpnIp + myVpnNet *net.IPNet punchConn *udp.Conn + punchy *Punchy // Local cache of answers from light houses // map of vpn Ip to answers @@ -39,80 +45,240 @@ type LightHouse struct { // respond with. // - When we are not a lighthouse, this filters which addresses we accept // from lighthouses. - remoteAllowList *RemoteAllowList + atomicRemoteAllowList *RemoteAllowList // filters local addresses that we advertise to lighthouses - localAllowList *LocalAllowList + atomicLocalAllowList *LocalAllowList // used to trigger the HandshakeManager when we receive HostQueryReply handshakeTrigger chan<- iputil.VpnIp - // staticList exists to avoid having a bool in each addrMap entry + // atomicStaticList exists to avoid having a bool in each addrMap entry // since static should be rare - staticList map[iputil.VpnIp]struct{} - lighthouses map[iputil.VpnIp]struct{} - interval int - nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - punchBack bool - punchDelay time.Duration + atomicStaticList map[iputil.VpnIp]struct{} + atomicLighthouses map[iputil.VpnIp]struct{} + + atomicInterval int64 + updateCancel context.CancelFunc + updateParentCtx context.Context + updateUdp udp.EncWriter + nebulaPort uint32 // 32 bits because protobuf does not have a uint16 metrics *MessageMetrics metricHolepunchTx metrics.Counter l *logrus.Logger } -func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse { - ones, _ := myVpnIpNet.Mask.Size() - h := LightHouse{ - amLighthouse: amLighthouse, - myVpnIp: iputil.Ip2VpnIp(myVpnIpNet.IP), - myVpnZeros: iputil.VpnIp(32 - ones), - addrMap: make(map[iputil.VpnIp]*RemoteList), - nebulaPort: nebulaPort, - lighthouses: make(map[iputil.VpnIp]struct{}), - staticList: make(map[iputil.VpnIp]struct{}), - interval: interval, - punchConn: pc, - punchBack: punchBack, - punchDelay: punchDelay, - l: l, +// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object +// addrMap should be nil unless this is during a config reload +func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { + amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) + nebulaPort := uint32(c.GetInt("listen.port", 0)) + if amLighthouse && nebulaPort == 0 { + return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) } - if metricsEnabled { - h.metrics = newLighthouseMetrics() + ones, _ := myVpnNet.Mask.Size() + h := LightHouse{ + amLighthouse: amLighthouse, + myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), + myVpnZeros: iputil.VpnIp(32 - ones), + myVpnNet: myVpnNet, + addrMap: make(map[iputil.VpnIp]*RemoteList), + nebulaPort: nebulaPort, + atomicLighthouses: make(map[iputil.VpnIp]struct{}), + atomicStaticList: make(map[iputil.VpnIp]struct{}), + punchConn: pc, + punchy: p, + l: l, + } + if c.GetBool("stats.lighthouse_metrics", false) { + h.metrics = newLighthouseMetrics() h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) } else { h.metricHolepunchTx = metrics.NilCounter{} } - for _, ip := range ips { - h.lighthouses[ip] = struct{}{} + err := h.reload(c, true) + if err != nil { + return nil, err } - return &h + c.RegisterReloadCallback(func(c *config.C) { + err := h.reload(c, false) + switch v := err.(type) { + case util.ContextualError: + v.Log(l) + case error: + l.WithError(err).Error("failed to reload lighthouse") + } + }) + + return &h, nil } -func (lh *LightHouse) SetRemoteAllowList(allowList *RemoteAllowList) { - lh.Lock() - defer lh.Unlock() - - lh.remoteAllowList = allowList +func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { + return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList)))) } -func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) { - lh.Lock() - defer lh.Unlock() - - lh.localAllowList = allowList +func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { + return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses)))) } -func (lh *LightHouse) ValidateLHStaticEntries() error { - for lhIP, _ := range lh.lighthouses { - if _, ok := lh.staticList[lhIP]; !ok { - return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP) +func (lh *LightHouse) GetRemoteAllowList() *RemoteAllowList { + return (*RemoteAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList)))) +} + +func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { + return (*LocalAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList)))) +} + +func (lh *LightHouse) GetUpdateInterval() int64 { + return atomic.LoadInt64(&lh.atomicInterval) +} + +func (lh *LightHouse) reload(c *config.C, initial bool) error { + if initial || c.HasChanged("lighthouse.interval") { + atomic.StoreInt64(&lh.atomicInterval, int64(c.GetInt("lighthouse.interval", 10))) + + if !initial { + lh.l.Infof("lighthouse.interval changed to %v", lh.atomicInterval) + + if lh.updateCancel != nil { + // May not always have a running routine + lh.updateCancel() + } + + lh.LhUpdateWorker(lh.updateParentCtx, lh.updateUdp) } } + + if initial || c.HasChanged("lighthouse.remote_allow_list") || c.HasChanged("lighthouse.remote_allow_ranges") { + ral, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") + if err != nil { + return util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) + } + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList)), unsafe.Pointer(ral)) + if !initial { + //TODO: a diff will be annoyingly difficult + lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed") + } + } + + if initial || c.HasChanged("lighthouse.local_allow_list") { + lal, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") + if err != nil { + return util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) + } + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList)), unsafe.Pointer(lal)) + if !initial { + //TODO: a diff will be annoyingly difficult + lh.l.Info("lighthouse.local_allow_list has changed") + } + } + + //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config + if initial || c.HasChanged("static_host_map") { + staticList := make(map[iputil.VpnIp]struct{}) + err := lh.loadStaticMap(c, lh.myVpnNet, staticList) + if err != nil { + return err + } + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList)), unsafe.Pointer(&staticList)) + if !initial { + //TODO: we should remove any remote list entries for static hosts that were removed/modified? + lh.l.Info("static_host_map has changed") + } + + } + + if initial || c.HasChanged("lighthouse.hosts") { + lhMap := make(map[iputil.VpnIp]struct{}) + err := lh.parseLighthouses(c, lh.myVpnNet, lhMap) + if err != nil { + return err + } + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses)), unsafe.Pointer(&lhMap)) + if !initial { + //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic + lh.l.Info("lighthouse.hosts has changed") + } + } + + return nil +} + +func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error { + lhs := c.GetStringSlice("lighthouse.hosts", []string{}) + if lh.amLighthouse && len(lhs) != 0 { + lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") + } + + for i, host := range lhs { + ip := net.ParseIP(host) + if ip == nil { + return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) + } + if !tunCidr.Contains(ip) { + return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + } + lhMap[iputil.Ip2VpnIp(ip)] = struct{}{} + } + + if !lh.amLighthouse && len(lhMap) == 0 { + lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries") + } + + staticList := lh.GetStaticHostList() + for lhIP, _ := range lhMap { + if _, ok := staticList[lhIP]; !ok { + return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhIP) + } + } + + return nil +} + +func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { + shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) + i := 0 + + for k, v := range shm { + rip := net.ParseIP(fmt.Sprintf("%v", k)) + if rip == nil { + return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil) + } + + if !tunCidr.Contains(rip) { + return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil) + } + + vpnIp := iputil.Ip2VpnIp(rip) + vals, ok := v.([]interface{}) + if ok { + for _, v := range vals { + ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) + if err != nil { + return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + } + lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) + } + + } else { + ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) + if err != nil { + return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + } + lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) + } + i++ + } + return nil } @@ -146,10 +312,11 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) { return } - lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses))) + lighthouses := lh.GetLighthouses() + lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) nb := make([]byte, 12, 12) out := make([]byte, mtu) - for n := range lh.lighthouses { + for n := range lighthouses { f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) } } @@ -197,7 +364,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.staticList[vpnIp]; ok { + if _, ok := lh.GetStaticHostList()[vpnIp]; ok { return } lh.Lock() @@ -211,10 +378,11 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { lh.Unlock() } -// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner +// addStaticRemote adds a static host entry for vpnIp as ourselves as the owner // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client -func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) { +//NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it +func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() @@ -236,8 +404,8 @@ func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) { am.unlockedPrependV6(lh.myVpnIp, to) } - // Mark it as static - lh.staticList[vpnIp] = struct{}{} + // Mark it as static in the caller provided map + staticList[vpnIp] = struct{}{} } // unlockedGetRemoteList assumes you have the lh lock @@ -252,7 +420,7 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { // unlockedShouldAddV4 checks if to is allowed by our allow list func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { - allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) + allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } @@ -266,7 +434,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo // unlockedShouldAddV6 checks if to is allowed by our allow list func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { - allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo) + allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } @@ -287,7 +455,7 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP { } func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { - if _, ok := lh.lighthouses[vpnIp]; ok { + if _, ok := lh.GetLighthouses()[vpnIp]; ok { return true } return false @@ -329,18 +497,24 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { } func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { - if lh.amLighthouse || lh.interval == 0 { + lh.updateParentCtx = ctx + lh.updateUdp = f + + interval := lh.GetUpdateInterval() + if lh.amLighthouse || interval == 0 { return } - clockSource := time.NewTicker(time.Second * time.Duration(lh.interval)) + clockSource := time.NewTicker(time.Second * time.Duration(interval)) + updateCtx, cancel := context.WithCancel(ctx) + lh.updateCancel = cancel defer clockSource.Stop() for { lh.SendUpdate(f) select { - case <-ctx.Done(): + case <-updateCtx.Done(): return case <-clockSource.C: continue @@ -352,7 +526,8 @@ func (lh *LightHouse) SendUpdate(f udp.EncWriter) { var v4 []*Ip4AndPort var v6 []*Ip6AndPort - for _, e := range *localIps(lh.l, lh.localAllowList) { + lal := lh.GetLocalAllowList() + for _, e := range *localIps(lh.l, lal) { if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { continue } @@ -373,7 +548,8 @@ func (lh *LightHouse) SendUpdate(f udp.EncWriter) { }, } - lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses))) + lighthouses := lh.GetLighthouses() + lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lighthouses))) nb := make([]byte, 12, 12) out := make([]byte, mtu) @@ -383,7 +559,7 @@ func (lh *LightHouse) SendUpdate(f udp.EncWriter) { return } - for vpnIp := range lh.lighthouses { + for vpnIp := range lighthouses { f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) } } @@ -609,7 +785,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i } go func() { - time.Sleep(lhh.lh.punchDelay) + time.Sleep(lhh.lh.punchy.GetDelay()) lhh.lh.metricHolepunchTx.Inc(1) lhh.lh.punchConn.WriteTo(empty, vpnPeer) }() @@ -631,7 +807,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. - if lhh.lh.punchBack { + if lhh.lh.punchy.GetRespond() { queryVpnIp := iputil.VpnIp(n.Details.VpnIp) go func() { time.Sleep(time.Second * 5) diff --git a/lighthouse_test.go b/lighthouse_test.go index 41fde97..d6be9a3 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/golang/protobuf/proto" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" @@ -47,33 +48,32 @@ func TestNewLhQuery(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() + _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") lh1 := "10.128.0.2" - lh1IP := net.ParseIP(lh1) - udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2) - - meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false) - meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242))) - err := meh.ValidateLHStaticEntries() + c := config.NewC(l) + c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} + c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} + _, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" - lh2IP := net.ParseIP(lh2) - - meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false) - meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242))) - err = meh.ValidateLHStaticEntries() - assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") + c = config.NewC(l) + c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} + c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} + _, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() - lh1 := "10.128.0.2" - lh1IP := net.ParseIP(lh1) + _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") - udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2) - - lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false) + c := config.NewC(l) + lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + if !assert.NoError(b, err) { + b.Fatal() + } hAddr := udp.NewAddrFromString("4.5.6.7:12345") hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") @@ -160,8 +160,11 @@ func TestLighthouse_Memory(t *testing.T) { theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) - udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2) - lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false) + c := config.NewC(l) + c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} + c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} + lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that @@ -179,9 +182,16 @@ func TestLighthouse_Memory(t *testing.T) { r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) - // Update a different host + // Update a different host and ask about it newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) + r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) + assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + + // Have both hosts ask about the other r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) + assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + + r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Make sure we didn't get changed @@ -224,6 +234,18 @@ func TestLighthouse_Memory(t *testing.T) { assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) } +func TestLighthouse_reload(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} + c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} + lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + assert.NoError(t, err) + + c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}} + lh.reload(c, false) +} + func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, @@ -237,7 +259,10 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh panic(err) } - w := &testEncWriter{} + filter := NebulaMeta_HostQueryReply + w := &testEncWriter{ + metaFilter: &filter, + } lhh.HandleRequest(fromAddr, myVpnIp, b, w) return w.lastReply } @@ -344,18 +369,22 @@ type testLhReply struct { } type testEncWriter struct { - lastReply testLhReply + lastReply testLhReply + metaFilter *NebulaMeta_MessageType } func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { - tw.lastReply = testLhReply{ - nebType: t, - nebSubType: st, - vpnIp: vpnIp, - msg: &NebulaMeta{}, + msg := &NebulaMeta{} + err := proto.Unmarshal(p, msg) + if tw.metaFilter == nil || msg.Type == *tw.metaFilter { + tw.lastReply = testLhReply{ + nebType: t, + nebSubType: st, + vpnIp: vpnIp, + msg: msg, + } } - err := proto.Unmarshal(p, tw.lastReply.msg) if err != nil { panic(err) } @@ -363,7 +392,10 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { - assert.Len(t, have, len(want)) + if !assert.Len(t, have, len(want)) { + return + } + for k, w := range want { if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) { assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) @@ -373,7 +405,10 @@ func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { // assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) { - assert.Len(t, have, len(want)) + if !assert.Len(t, have, len(want)) { + return + } + for k, w := range want { if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have)) diff --git a/main.go b/main.go index 004f739..ec08817 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,13 @@ package nebula import ( "context" "encoding/binary" + "errors" "fmt" "net" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/udp" @@ -218,95 +218,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg go hostMap.Promoter(config.GetInt("promoter.interval")) */ - punchy := NewPunchyFromConfig(c) - if punchy.Punch && !configTest { + punchy := NewPunchyFromConfig(l, c) + if punchy.GetPunch() && !configTest { l.Info("UDP hole punching enabled") go hostMap.Punchy(ctx, udpConns[0]) } - amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) - - // fatal if am_lighthouse is enabled but we are using an ephemeral port - if amLighthouse && (c.GetInt("listen.port", 0) == 0) { - return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) - } - - // warn if am_lighthouse is enabled but upstream lighthouses exists - rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{}) - if amLighthouse && len(rawLighthouseHosts) != 0 { - l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") - } - - lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts)) - for i, host := range rawLighthouseHosts { - ip := net.ParseIP(host) - if ip == nil { - return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) - } - if !tunCidr.Contains(ip) { - return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) - } - lighthouseHosts[i] = iputil.Ip2VpnIp(ip) - } - - if !amLighthouse && len(lighthouseHosts) == 0 { - l.Warn("No lighthouses.hosts configured, this host will only be able to initiate tunnels with static_host_map entries") - } - - lightHouse := NewLightHouse( - l, - amLighthouse, - tunCidr, - lighthouseHosts, - //TODO: change to a duration - c.GetInt("lighthouse.interval", 10), - uint32(port), - udpConns[0], - punchy.Respond, - punchy.Delay, - c.GetBool("stats.lighthouse_metrics", false), - ) - - remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") - if err != nil { - return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) - } - lightHouse.SetRemoteAllowList(remoteAllowList) - - localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list") - if err != nil { - return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) - } - lightHouse.SetLocalAllowList(localAllowList) - - //TODO: Move all of this inside functions in lighthouse.go - for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) { - ip := net.ParseIP(fmt.Sprintf("%v", k)) - vpnIp := iputil.Ip2VpnIp(ip) - if !tunCidr.Contains(ip) { - return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) - } - vals, ok := v.([]interface{}) - if ok { - for _, v := range vals { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) - } - lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) - } - } else { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) - } - lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port)) - } - } - - err = lightHouse.ValidateLHStaticEntries() - if err != nil { - l.WithError(err).Error("Lighthouse unreachable") + lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy) + switch { + case errors.As(err, &util.ContextualError{}): + return nil, err + case err != nil: + return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err) } var messageMetrics *MessageMetrics @@ -411,7 +334,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // Start DNS server last to allow using the nebula IP as lighthouse.dns.host var dnsStart func() - if amLighthouse && serveDns { + if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") dnsStart = dnsMain(l, hostMap, c) } diff --git a/outside.go b/outside.go index e0dba7d..ea96df1 100644 --- a/outside.go +++ b/outside.go @@ -157,7 +157,7 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { if !hostinfo.remote.Equals(addr) { - if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") return } diff --git a/punchy.go b/punchy.go index 90d7b94..d81ed83 100644 --- a/punchy.go +++ b/punchy.go @@ -1,34 +1,89 @@ package nebula import ( + "sync/atomic" "time" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type Punchy struct { - Punch bool - Respond bool - Delay time.Duration + atomicPunch int32 + atomicRespond int32 + atomicDelay time.Duration + l *logrus.Logger } -func NewPunchyFromConfig(c *config.C) *Punchy { - p := &Punchy{} +func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { + p := &Punchy{l: l} - if c.IsSet("punchy.punch") { - p.Punch = c.GetBool("punchy.punch", false) - } else { - // Deprecated fallback - p.Punch = c.GetBool("punchy", false) - } + p.reload(c, true) + c.RegisterReloadCallback(func(c *config.C) { + p.reload(c, false) + }) - if c.IsSet("punchy.respond") { - p.Respond = c.GetBool("punchy.respond", false) - } else { - // Deprecated fallback - p.Respond = c.GetBool("punch_back", false) - } - - p.Delay = c.GetDuration("punchy.delay", time.Second) return p } + +func (p *Punchy) reload(c *config.C, initial bool) { + if initial { + var yes bool + if c.IsSet("punchy.punch") { + yes = c.GetBool("punchy.punch", false) + } else { + // Deprecated fallback + yes = c.GetBool("punchy", false) + } + + if yes { + atomic.StoreInt32(&p.atomicPunch, 1) + } else { + atomic.StoreInt32(&p.atomicPunch, 0) + } + + } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { + //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here + p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") + } + + if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") { + var yes bool + if c.IsSet("punchy.respond") { + yes = c.GetBool("punchy.respond", false) + } else { + // Deprecated fallback + yes = c.GetBool("punch_back", false) + } + + if yes { + atomic.StoreInt32(&p.atomicRespond, 1) + } else { + atomic.StoreInt32(&p.atomicRespond, 0) + } + + if !initial { + p.l.Infof("punchy.respond changed to %v", p.GetRespond()) + } + } + + //NOTE: this will not apply to any in progress operations, only the next one + if initial || c.HasChanged("punchy.delay") { + atomic.StoreInt64((*int64)(&p.atomicDelay), (int64)(c.GetDuration("punchy.delay", time.Second))) + if !initial { + p.l.Infof("punchy.delay changed to %s", p.GetDelay()) + } + } +} + +func (p *Punchy) GetPunch() bool { + return atomic.LoadInt32(&p.atomicPunch) == 1 +} + +func (p *Punchy) GetRespond() bool { + return atomic.LoadInt32(&p.atomicRespond) == 1 +} + +func (p *Punchy) GetDelay() time.Duration { + return (time.Duration)(atomic.LoadInt64((*int64)(&p.atomicDelay))) +} diff --git a/punchy_test.go b/punchy_test.go index 89b5136..0aa9b62 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -14,34 +14,58 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(c) - assert.Equal(t, false, p.Punch) - assert.Equal(t, false, p.Respond) - assert.Equal(t, time.Second, p.Delay) + p := NewPunchyFromConfig(l, c) + assert.Equal(t, false, p.GetPunch()) + assert.Equal(t, false, p.GetRespond()) + assert.Equal(t, time.Second, p.GetDelay()) // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(c) - assert.Equal(t, true, p.Punch) + p = NewPunchyFromConfig(l, c) + assert.Equal(t, true, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[interface{}]interface{}{"punch": true} - p = NewPunchyFromConfig(c) - assert.Equal(t, true, p.Punch) + p = NewPunchyFromConfig(l, c) + assert.Equal(t, true, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(c) - assert.Equal(t, true, p.Respond) + p = NewPunchyFromConfig(l, c) + assert.Equal(t, true, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[interface{}]interface{}{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(c) - assert.Equal(t, true, p.Respond) + p = NewPunchyFromConfig(l, c) + assert.Equal(t, true, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} - p = NewPunchyFromConfig(c) - assert.Equal(t, time.Minute, p.Delay) + p = NewPunchyFromConfig(l, c) + assert.Equal(t, time.Minute, p.GetDelay()) +} + +func TestPunchy_reload(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + delay, _ := time.ParseDuration("1m") + assert.NoError(t, c.LoadString(` +punchy: + delay: 1m + respond: false +`)) + p := NewPunchyFromConfig(l, c) + assert.Equal(t, delay, p.GetDelay()) + assert.Equal(t, false, p.GetRespond()) + + newDelay, _ := time.ParseDuration("10m") + assert.NoError(t, c.ReloadConfigString(` +punchy: + delay: 10m + respond: true +`)) + p.reload(c, false) + assert.Equal(t, newDelay, p.GetDelay()) + assert.Equal(t, true, p.GetRespond()) }