diff --git a/firewall.go b/firewall.go index adecbe81..904c71b2 100644 --- a/firewall.go +++ b/firewall.go @@ -58,8 +58,9 @@ type Firewall struct { routableNetworks *bart.Lite // assignedNetworks is a list of vpn networks assigned to us in the certificate. - assignedNetworks []netip.Prefix - hasUnsafeNetworks bool + assignedNetworks []netip.Prefix + // unsafeNetworks is the list of unsafe networks issued to us in the certificate + unsafeNetworks []netip.Prefix rules string rulesVersion uint16 @@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur assignedNetworks = append(assignedNetworks, network) } - hasUnsafeNetworks := false - for _, n := range c.UnsafeNetworks() { + unsafeNetworks := c.UnsafeNetworks() + for _, n := range unsafeNetworks { routableNetworks.Insert(n) - hasUnsafeNetworks = true } return &Firewall{ @@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - routableNetworks: routableNetworks, - assignedNetworks: assignedNetworks, - hasUnsafeNetworks: hasUnsafeNetworks, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, + unsafeNetworks: unsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), @@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { } if localCidr == "" { - if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { + if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny { flc.Any = true return nil } diff --git a/interface.go b/interface.go index 32f5c2a6..f96e431a 100644 --- a/interface.go +++ b/interface.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/netip" + "slices" "sync" "sync/atomic" "time" @@ -14,6 +15,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { } func (f *Interface) reloadFirewall(c *config.C) { - //TODO: need to trigger/detect if the certificate changed too - if c.HasChanged("firewall") == false { + cs := f.pki.getCertState() + curCert := cs.getCertificate(cert.Version2) + if curCert == nil { + curCert = cs.getCertificate(cert.Version1) + } + + // The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction. + // Check to see if that set has changed, and if so, rebuild the firewall. + certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks) + + if !c.HasChanged("firewall") && !certUnsafeChanged { f.l.Debug("No firewall config change detected") return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) + fw, err := NewFirewallFromConfig(f.l, cs, c) if err != nil { f.l.Error("Error while creating firewall during reload", "error", err) return diff --git a/interface_emit_test.go b/interface_emit_test.go new file mode 100644 index 00000000..b0a9d025 --- /dev/null +++ b/interface_emit_test.go @@ -0,0 +1,73 @@ +//go:build linux || darwin + +package nebula + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/test" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that +// landed before the first ticker fire used to read 0 for the cert gauges. +// emitStats now primes the gauges before entering the ticker loop. We assert +// the gauge is zero before the first call and non-zero after. +func Test_emitStats_primesGauges(t *testing.T) { + defer metrics.DefaultRegistry.UnregisterAll() + + l := test.NewLogger() + hostMap := newHostMap(l) + preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + hostMap.preferredRanges.Store(&preferredRanges) + + notAfter := time.Now().Add(time.Hour) + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, + v1Credential: nil, + } + + lh := newTestLighthouse() + ifce := &Interface{ + hostMap: hostMap, + inside: &overlaytest.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, + lightHouse: lh, + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to + // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, + // returns an error, and the emitter falls through to a no-op. + writers: []udp.Conn{&udp.StdConn{}}, + } + ifce.pki.cs.Store(cs) + + ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + + // Pre-cancel the context so emitStats returns after priming the gauges + // without ever reading from ticker.C. The one hour interval is just a + // belt-and-suspenders, the test does not expect the ticker to fire. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ifce.emitStats(ctx, time.Hour) + + ttl := ttlGauge.Value() + assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") + assert.LessOrEqual(t, ttl, int64(3600)) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) +} diff --git a/interface_test.go b/interface_test.go index b0a9d025..1b912bbb 100644 --- a/interface_test.go +++ b/interface_test.go @@ -1,73 +1,120 @@ -//go:build linux || darwin - package nebula import ( - "context" "net/netip" "testing" - "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that -// landed before the first ticker fire used to read 0 for the cert gauges. -// emitStats now primes the gauges before entering the ticker loop. We assert -// the gauge is zero before the first call and non-zero after. -func Test_emitStats_primesGauges(t *testing.T) { - defer metrics.DefaultRegistry.UnregisterAll() - +// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall +// rebuilds the firewall when only the certificate's UnsafeNetworks have changed, +// even if the firewall section of the YAML has not. +func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) { l := test.NewLogger() - hostMap := newHostMap(l) - preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} - hostMap.preferredRanges.Store(&preferredRanges) - notAfter := time.Now().Add(time.Hour) - cs := &CertState{ - initiatingVersion: cert.Version1, - privateKey: []byte{}, - v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, - v1Credential: nil, + vpnNet := netip.MustParsePrefix("10.0.0.1/24") + initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + + // dummyCert avoids dragging the real signing pipeline into a unit test. + c1 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: initialUnsafe, + } + pki := &PKI{} + pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2}) + + rawYAML := `firewall: + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +` + cfg := config.NewC(l) + require.NoError(t, cfg.LoadString(rawYAML)) + + fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg) + require.NoError(t, err) + require.Equal(t, initialUnsafe, fw.unsafeNetworks) + + f := &Interface{ + pki: pki, + firewall: fw, + l: l, } - lh := newTestLighthouse() - ifce := &Interface{ - hostMap: hostMap, - inside: &overlaytest.NoopTun{}, - outside: &udp.NoopConn{}, - firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, - lightHouse: lh, - pki: &PKI{}, - handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), - l: l, - // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to - // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, - // returns an error, and the emitter falls through to a no-op. - writers: []udp.Conn{&udp.StdConn{}}, + // Swap the cert with a different UnsafeNetworks set. + newUnsafe := []netip.Prefix{ + netip.MustParsePrefix("198.51.100.0/24"), + netip.MustParsePrefix("203.0.113.0/24"), } - ifce.pki.cs.Store(cs) + c2 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: newUnsafe, + } + pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2}) - ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) - require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + // Reload with the same YAML so HasChanged("firewall") reports false. + require.NoError(t, cfg.ReloadConfigString(rawYAML)) + require.False(t, cfg.HasChanged("firewall")) - // Pre-cancel the context so emitStats returns after priming the gauges - // without ever reading from ticker.C. The one hour interval is just a - // belt-and-suspenders, the test does not expect the ticker to fire. - ctx, cancel := context.WithCancel(context.Background()) - cancel() - ifce.emitStats(ctx, time.Hour) + f.reloadFirewall(cfg) - ttl := ttlGauge.Value() - assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") - assert.LessOrEqual(t, ttl, int64(3600)) - assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) - assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) + assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced") + assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks) + assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5"))) +} + +// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when +// neither the firewall config nor the cert's UnsafeNetworks have changed. +func TestReloadFirewall_NoChange(t *testing.T) { + l := test.NewLogger() + + vpnNet := netip.MustParsePrefix("10.0.0.1/24") + unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + + c1 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: unsafe, + } + pki := &PKI{} + pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2}) + + rawYAML := `firewall: + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +` + cfg := config.NewC(l) + require.NoError(t, cfg.LoadString(rawYAML)) + + fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg) + require.NoError(t, err) + + f := &Interface{ + pki: pki, + firewall: fw, + l: l, + } + + require.NoError(t, cfg.ReloadConfigString(rawYAML)) + f.reloadFirewall(cfg) + + assert.Same(t, fw, f.firewall, "firewall should not have been replaced") }