From d27ca77a540a067c66ebdb1a942a9eec3e56b889 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 8 May 2026 17:50:40 -0500 Subject: [PATCH] Make firewall reload when unsafe networks in the cert changes --- firewall.go | 30 ++++++------ interface.go | 17 +++++-- interface_test.go | 120 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 18 deletions(-) create mode 100644 interface_test.go diff --git a/firewall.go b/firewall.go index adecbe81..904c71b2 100644 --- a/firewall.go +++ b/firewall.go @@ -58,8 +58,9 @@ type Firewall struct { routableNetworks *bart.Lite // assignedNetworks is a list of vpn networks assigned to us in the certificate. - assignedNetworks []netip.Prefix - hasUnsafeNetworks bool + assignedNetworks []netip.Prefix + // unsafeNetworks is the list of unsafe networks issued to us in the certificate + unsafeNetworks []netip.Prefix rules string rulesVersion uint16 @@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur assignedNetworks = append(assignedNetworks, network) } - hasUnsafeNetworks := false - for _, n := range c.UnsafeNetworks() { + unsafeNetworks := c.UnsafeNetworks() + for _, n := range unsafeNetworks { routableNetworks.Insert(n) - hasUnsafeNetworks = true } return &Firewall{ @@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - routableNetworks: routableNetworks, - assignedNetworks: assignedNetworks, - hasUnsafeNetworks: hasUnsafeNetworks, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, 
+ unsafeNetworks: unsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), @@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { } if localCidr == "" { - if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { + if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny { flc.Any = true return nil } diff --git a/interface.go b/interface.go index 5fedcdd3..a62457cd 100644 --- a/interface.go +++ b/interface.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/netip" + "slices" "sync" "sync/atomic" "time" @@ -14,6 +15,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { } func (f *Interface) reloadFirewall(c *config.C) { - //TODO: need to trigger/detect if the certificate changed too - if c.HasChanged("firewall") == false { + cs := f.pki.getCertState() + curCert := cs.getCertificate(cert.Version2) + if curCert == nil { + curCert = cs.getCertificate(cert.Version1) + } + + // The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction. + // Check to see if that set has changed, and if so, rebuild the firewall. 
+ certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks) + + if !c.HasChanged("firewall") && !certUnsafeChanged { f.l.Debug("No firewall config change detected") return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) + fw, err := NewFirewallFromConfig(f.l, cs, c) if err != nil { f.l.Error("Error while creating firewall during reload", "error", err) return diff --git a/interface_test.go b/interface_test.go new file mode 100644 index 00000000..1b912bbb --- /dev/null +++ b/interface_test.go @@ -0,0 +1,120 @@ +package nebula + +import ( + "net/netip" + "testing" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall +// rebuilds the firewall when only the certificate's UnsafeNetworks have changed, +// even if the firewall section of the YAML has not. +func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) { + l := test.NewLogger() + + vpnNet := netip.MustParsePrefix("10.0.0.1/24") + initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + + // dummyCert avoids dragging the real signing pipeline into a unit test. 
+ c1 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: initialUnsafe, + } + pki := &PKI{} + pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2}) + + rawYAML := `firewall: + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +` + cfg := config.NewC(l) + require.NoError(t, cfg.LoadString(rawYAML)) + + fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg) + require.NoError(t, err) + require.Equal(t, initialUnsafe, fw.unsafeNetworks) + + f := &Interface{ + pki: pki, + firewall: fw, + l: l, + } + + // Swap the cert with a different UnsafeNetworks set. + newUnsafe := []netip.Prefix{ + netip.MustParsePrefix("198.51.100.0/24"), + netip.MustParsePrefix("203.0.113.0/24"), + } + c2 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: newUnsafe, + } + pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2}) + + // Reload with the same YAML so HasChanged("firewall") reports false. + require.NoError(t, cfg.ReloadConfigString(rawYAML)) + require.False(t, cfg.HasChanged("firewall")) + + f.reloadFirewall(cfg) + + assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced") + assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks) + assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5"))) +} + +// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when +// neither the firewall config nor the cert's UnsafeNetworks have changed. 
+func TestReloadFirewall_NoChange(t *testing.T) {
+	l := test.NewLogger()
+
+	vpnNet := netip.MustParsePrefix("10.0.0.1/24")
+	unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
+	// This certificate is installed once and is never swapped during the test.
+	c1 := &dummyCert{
+		version:        cert.Version2,
+		networks:       []netip.Prefix{vpnNet},
+		unsafeNetworks: unsafe,
+	}
+	pki := &PKI{}
+	pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
+
+	rawYAML := `firewall:
+  outbound:
+    - port: any
+      proto: any
+      host: any
+  inbound:
+    - port: any
+      proto: any
+      host: any
+`
+	cfg := config.NewC(l)
+	require.NoError(t, cfg.LoadString(rawYAML))
+
+	fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
+	require.NoError(t, err)
+	f := &Interface{pki: pki, firewall: fw, l: l}
+
+	// Reload identical YAML and check the premise of this test: the config reports no
+	// firewall change, so neither reload trigger (config or cert) is set below.
+	require.NoError(t, cfg.ReloadConfigString(rawYAML))
+	require.False(t, cfg.HasChanged("firewall"))
+
+	f.reloadFirewall(cfg)
+
+	// Same pointer means reloadFirewall did not install a rebuilt firewall.
+	assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
+}