Make firewall reload when unsafe networks in the cert changes

This commit is contained in:
Nate Brown
2026-05-08 17:50:40 -05:00
parent ffd5249cf5
commit 43079bae04
4 changed files with 201 additions and 70 deletions

View File

@@ -58,8 +58,9 @@ type Firewall struct {
routableNetworks *bart.Lite routableNetworks *bart.Lite
// assignedNetworks is a list of vpn networks assigned to us in the certificate. // assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix assignedNetworks []netip.Prefix
hasUnsafeNetworks bool // unsafeNetworks is the list of unsafe networks issued to us in the certificate
unsafeNetworks []netip.Prefix
rules string rules string
rulesVersion uint16 rulesVersion uint16
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
assignedNetworks = append(assignedNetworks, network) assignedNetworks = append(assignedNetworks, network)
} }
hasUnsafeNetworks := false unsafeNetworks := c.UnsafeNetworks()
for _, n := range c.UnsafeNetworks() { for _, n := range unsafeNetworks {
routableNetworks.Insert(n) routableNetworks.Insert(n)
hasUnsafeNetworks = true
} }
return &Firewall{ return &Firewall{
@@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
Conns: make(map[firewall.Packet]*conn), Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
}, },
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
routableNetworks: routableNetworks, routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks, assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks, unsafeNetworks: unsafeNetworks,
l: l, l: l,
incomingMetrics: firewallMetrics{ incomingMetrics: firewallMetrics{
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
} }
if localCidr == "" { if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
flc.Any = true flc.Any = true
return nil return nil
} }

View File

@@ -7,6 +7,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -14,6 +15,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
} }
func (f *Interface) reloadFirewall(c *config.C) { func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too cs := f.pki.getCertState()
if c.HasChanged("firewall") == false { curCert := cs.getCertificate(cert.Version2)
if curCert == nil {
curCert = cs.getCertificate(cert.Version1)
}
// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
// Check to see if that set has changed, and if so, rebuild the firewall.
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)
if !c.HasChanged("firewall") && !certUnsafeChanged {
f.l.Debug("No firewall config change detected") f.l.Debug("No firewall config change detected")
return return
} }
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) fw, err := NewFirewallFromConfig(f.l, cs, c)
if err != nil { if err != nil {
f.l.Error("Error while creating firewall during reload", "error", err) f.l.Error("Error while creating firewall during reload", "error", err)
return return

73
interface_emit_test.go Normal file
View File

@@ -0,0 +1,73 @@
//go:build linux || darwin
package nebula
import (
"context"
"net/netip"
"testing"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that
// landed before the first ticker fire used to read 0 for the cert gauges.
// emitStats now primes the gauges before entering the ticker loop. We assert
// the gauge is zero before the first call and non-zero after.
func Test_emitStats_primesGauges(t *testing.T) {
defer metrics.DefaultRegistry.UnregisterAll()
l := test.NewLogger()
hostMap := newHostMap(l)
preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
hostMap.preferredRanges.Store(&preferredRanges)
notAfter := time.Now().Add(time.Hour)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter},
v1Credential: nil,
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &overlaytest.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
// On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to
// *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn,
// returns an error, and the emitter falls through to a no-op.
writers: []udp.Conn{&udp.StdConn{}},
}
ifce.pki.cs.Store(cs)
ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs")
// Pre-cancel the context so emitStats returns after priming the gauges
// without ever reading from ticker.C. The one hour interval is just a
// belt-and-suspenders, the test does not expect the ticker to fire.
ctx, cancel := context.WithCancel(context.Background())
cancel()
ifce.emitStats(ctx, time.Hour)
ttl := ttlGauge.Value()
assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick")
assert.LessOrEqual(t, ttl, int64(3600))
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value())
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value())
}

View File

@@ -1,73 +1,120 @@
//go:build linux || darwin
package nebula package nebula
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/overlaytest"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that // TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall
// landed before the first ticker fire used to read 0 for the cert gauges. // rebuilds the firewall when only the certificate's UnsafeNetworks have changed,
// emitStats now primes the gauges before entering the ticker loop. We assert // even if the firewall section of the YAML has not.
// the gauge is zero before the first call and non-zero after. func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) {
func Test_emitStats_primesGauges(t *testing.T) {
defer metrics.DefaultRegistry.UnregisterAll()
l := test.NewLogger() l := test.NewLogger()
hostMap := newHostMap(l)
preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
hostMap.preferredRanges.Store(&preferredRanges)
notAfter := time.Now().Add(time.Hour) vpnNet := netip.MustParsePrefix("10.0.0.1/24")
cs := &CertState{ initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
initiatingVersion: cert.Version1,
privateKey: []byte{}, // dummyCert avoids dragging the real signing pipeline into a unit test.
v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, c1 := &dummyCert{
v1Credential: nil, version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: initialUnsafe,
}
pki := &PKI{}
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
rawYAML := `firewall:
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
`
cfg := config.NewC(l)
require.NoError(t, cfg.LoadString(rawYAML))
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
require.NoError(t, err)
require.Equal(t, initialUnsafe, fw.unsafeNetworks)
f := &Interface{
pki: pki,
firewall: fw,
l: l,
} }
lh := newTestLighthouse() // Swap the cert with a different UnsafeNetworks set.
ifce := &Interface{ newUnsafe := []netip.Prefix{
hostMap: hostMap, netip.MustParsePrefix("198.51.100.0/24"),
inside: &overlaytest.NoopTun{}, netip.MustParsePrefix("203.0.113.0/24"),
outside: &udp.NoopConn{},
firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
// On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to
// *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn,
// returns an error, and the emitter falls through to a no-op.
writers: []udp.Conn{&udp.StdConn{}},
} }
ifce.pki.cs.Store(cs) c2 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: newUnsafe,
}
pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2})
ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) // Reload with the same YAML so HasChanged("firewall") reports false.
require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") require.NoError(t, cfg.ReloadConfigString(rawYAML))
require.False(t, cfg.HasChanged("firewall"))
// Pre-cancel the context so emitStats returns after priming the gauges f.reloadFirewall(cfg)
// without ever reading from ticker.C. The one hour interval is just a
// belt-and-suspenders, the test does not expect the ticker to fire.
ctx, cancel := context.WithCancel(context.Background())
cancel()
ifce.emitStats(ctx, time.Hour)
ttl := ttlGauge.Value() assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced")
assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks)
assert.LessOrEqual(t, ttl, int64(3600)) assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5")))
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) }
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value())
// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when
// neither the firewall config nor the cert's UnsafeNetworks have changed.
func TestReloadFirewall_NoChange(t *testing.T) {
l := test.NewLogger()
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
c1 := &dummyCert{
version: cert.Version2,
networks: []netip.Prefix{vpnNet},
unsafeNetworks: unsafe,
}
pki := &PKI{}
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
rawYAML := `firewall:
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
`
cfg := config.NewC(l)
require.NoError(t, cfg.LoadString(rawYAML))
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
require.NoError(t, err)
f := &Interface{
pki: pki,
firewall: fw,
l: l,
}
require.NoError(t, cfg.ReloadConfigString(rawYAML))
f.reloadFirewall(cfg)
assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
} }