mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 12:57:38 +02:00
Make firewall reload when unsafe networks in the cert changes
This commit is contained in:
30
firewall.go
30
firewall.go
@@ -58,8 +58,9 @@ type Firewall struct {
|
|||||||
routableNetworks *bart.Lite
|
routableNetworks *bart.Lite
|
||||||
|
|
||||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||||
assignedNetworks []netip.Prefix
|
assignedNetworks []netip.Prefix
|
||||||
hasUnsafeNetworks bool
|
// unsafeNetworks is the list of unsafe networks issued to us in the certificate
|
||||||
|
unsafeNetworks []netip.Prefix
|
||||||
|
|
||||||
rules string
|
rules string
|
||||||
rulesVersion uint16
|
rulesVersion uint16
|
||||||
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
|||||||
assignedNetworks = append(assignedNetworks, network)
|
assignedNetworks = append(assignedNetworks, network)
|
||||||
}
|
}
|
||||||
|
|
||||||
hasUnsafeNetworks := false
|
unsafeNetworks := c.UnsafeNetworks()
|
||||||
for _, n := range c.UnsafeNetworks() {
|
for _, n := range unsafeNetworks {
|
||||||
routableNetworks.Insert(n)
|
routableNetworks.Insert(n)
|
||||||
hasUnsafeNetworks = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Firewall{
|
return &Firewall{
|
||||||
@@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
|||||||
Conns: make(map[firewall.Packet]*conn),
|
Conns: make(map[firewall.Packet]*conn),
|
||||||
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
||||||
},
|
},
|
||||||
InRules: newFirewallTable(),
|
InRules: newFirewallTable(),
|
||||||
OutRules: newFirewallTable(),
|
OutRules: newFirewallTable(),
|
||||||
TCPTimeout: tcpTimeout,
|
TCPTimeout: tcpTimeout,
|
||||||
UDPTimeout: UDPTimeout,
|
UDPTimeout: UDPTimeout,
|
||||||
DefaultTimeout: defaultTimeout,
|
DefaultTimeout: defaultTimeout,
|
||||||
routableNetworks: routableNetworks,
|
routableNetworks: routableNetworks,
|
||||||
assignedNetworks: assignedNetworks,
|
assignedNetworks: assignedNetworks,
|
||||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
unsafeNetworks: unsafeNetworks,
|
||||||
l: l,
|
l: l,
|
||||||
|
|
||||||
incomingMetrics: firewallMetrics{
|
incomingMetrics: firewallMetrics{
|
||||||
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
|
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
|
||||||
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if localCidr == "" {
|
if localCidr == "" {
|
||||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
|
||||||
flc.Any = true
|
flc.Any = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
17
interface.go
17
interface.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -14,6 +15,7 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) reloadFirewall(c *config.C) {
|
func (f *Interface) reloadFirewall(c *config.C) {
|
||||||
//TODO: need to trigger/detect if the certificate changed too
|
cs := f.pki.getCertState()
|
||||||
if c.HasChanged("firewall") == false {
|
curCert := cs.getCertificate(cert.Version2)
|
||||||
|
if curCert == nil {
|
||||||
|
curCert = cs.getCertificate(cert.Version1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
|
||||||
|
// Check to see if that set has changed, and if so, rebuild the firewall.
|
||||||
|
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)
|
||||||
|
|
||||||
|
if !c.HasChanged("firewall") && !certUnsafeChanged {
|
||||||
f.l.Debug("No firewall config change detected")
|
f.l.Debug("No firewall config change detected")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
fw, err := NewFirewallFromConfig(f.l, cs, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Error("Error while creating firewall during reload", "error", err)
|
f.l.Error("Error while creating firewall during reload", "error", err)
|
||||||
return
|
return
|
||||||
|
|||||||
73
interface_emit_test.go
Normal file
73
interface_emit_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rcrowley/go-metrics"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/overlay/overlaytest"
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that
|
||||||
|
// landed before the first ticker fire used to read 0 for the cert gauges.
|
||||||
|
// emitStats now primes the gauges before entering the ticker loop. We assert
|
||||||
|
// the gauge is zero before the first call and non-zero after.
|
||||||
|
func Test_emitStats_primesGauges(t *testing.T) {
|
||||||
|
defer metrics.DefaultRegistry.UnregisterAll()
|
||||||
|
|
||||||
|
l := test.NewLogger()
|
||||||
|
hostMap := newHostMap(l)
|
||||||
|
preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
|
||||||
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
|
notAfter := time.Now().Add(time.Hour)
|
||||||
|
cs := &CertState{
|
||||||
|
initiatingVersion: cert.Version1,
|
||||||
|
privateKey: []byte{},
|
||||||
|
v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter},
|
||||||
|
v1Credential: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
lh := newTestLighthouse()
|
||||||
|
ifce := &Interface{
|
||||||
|
hostMap: hostMap,
|
||||||
|
inside: &overlaytest.NoopTun{},
|
||||||
|
outside: &udp.NoopConn{},
|
||||||
|
firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
|
||||||
|
lightHouse: lh,
|
||||||
|
pki: &PKI{},
|
||||||
|
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
||||||
|
l: l,
|
||||||
|
// On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to
|
||||||
|
// *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn,
|
||||||
|
// returns an error, and the emitter falls through to a no-op.
|
||||||
|
writers: []udp.Conn{&udp.StdConn{}},
|
||||||
|
}
|
||||||
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
|
ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
||||||
|
require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs")
|
||||||
|
|
||||||
|
// Pre-cancel the context so emitStats returns after priming the gauges
|
||||||
|
// without ever reading from ticker.C. The one hour interval is just a
|
||||||
|
// belt-and-suspenders, the test does not expect the ticker to fire.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
ifce.emitStats(ctx, time.Hour)
|
||||||
|
|
||||||
|
ttl := ttlGauge.Value()
|
||||||
|
assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick")
|
||||||
|
assert.LessOrEqual(t, ttl, int64(3600))
|
||||||
|
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value())
|
||||||
|
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value())
|
||||||
|
}
|
||||||
@@ -1,73 +1,120 @@
|
|||||||
//go:build linux || darwin
|
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/overlay/overlaytest"
|
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that
|
// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall
|
||||||
// landed before the first ticker fire used to read 0 for the cert gauges.
|
// rebuilds the firewall when only the certificate's UnsafeNetworks have changed,
|
||||||
// emitStats now primes the gauges before entering the ticker loop. We assert
|
// even if the firewall section of the YAML has not.
|
||||||
// the gauge is zero before the first call and non-zero after.
|
func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) {
|
||||||
func Test_emitStats_primesGauges(t *testing.T) {
|
|
||||||
defer metrics.DefaultRegistry.UnregisterAll()
|
|
||||||
|
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
hostMap := newHostMap(l)
|
|
||||||
preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
|
|
||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
|
||||||
|
|
||||||
notAfter := time.Now().Add(time.Hour)
|
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
|
||||||
cs := &CertState{
|
initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
|
||||||
initiatingVersion: cert.Version1,
|
|
||||||
privateKey: []byte{},
|
// dummyCert avoids dragging the real signing pipeline into a unit test.
|
||||||
v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter},
|
c1 := &dummyCert{
|
||||||
v1Credential: nil,
|
version: cert.Version2,
|
||||||
|
networks: []netip.Prefix{vpnNet},
|
||||||
|
unsafeNetworks: initialUnsafe,
|
||||||
|
}
|
||||||
|
pki := &PKI{}
|
||||||
|
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
|
||||||
|
|
||||||
|
rawYAML := `firewall:
|
||||||
|
outbound:
|
||||||
|
- port: any
|
||||||
|
proto: any
|
||||||
|
host: any
|
||||||
|
inbound:
|
||||||
|
- port: any
|
||||||
|
proto: any
|
||||||
|
host: any
|
||||||
|
`
|
||||||
|
cfg := config.NewC(l)
|
||||||
|
require.NoError(t, cfg.LoadString(rawYAML))
|
||||||
|
|
||||||
|
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, initialUnsafe, fw.unsafeNetworks)
|
||||||
|
|
||||||
|
f := &Interface{
|
||||||
|
pki: pki,
|
||||||
|
firewall: fw,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
// Swap the cert with a different UnsafeNetworks set.
|
||||||
ifce := &Interface{
|
newUnsafe := []netip.Prefix{
|
||||||
hostMap: hostMap,
|
netip.MustParsePrefix("198.51.100.0/24"),
|
||||||
inside: &overlaytest.NoopTun{},
|
netip.MustParsePrefix("203.0.113.0/24"),
|
||||||
outside: &udp.NoopConn{},
|
|
||||||
firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}},
|
|
||||||
lightHouse: lh,
|
|
||||||
pki: &PKI{},
|
|
||||||
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
|
||||||
l: l,
|
|
||||||
// On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to
|
|
||||||
// *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn,
|
|
||||||
// returns an error, and the emitter falls through to a no-op.
|
|
||||||
writers: []udp.Conn{&udp.StdConn{}},
|
|
||||||
}
|
}
|
||||||
ifce.pki.cs.Store(cs)
|
c2 := &dummyCert{
|
||||||
|
version: cert.Version2,
|
||||||
|
networks: []netip.Prefix{vpnNet},
|
||||||
|
unsafeNetworks: newUnsafe,
|
||||||
|
}
|
||||||
|
pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2})
|
||||||
|
|
||||||
ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
// Reload with the same YAML so HasChanged("firewall") reports false.
|
||||||
require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs")
|
require.NoError(t, cfg.ReloadConfigString(rawYAML))
|
||||||
|
require.False(t, cfg.HasChanged("firewall"))
|
||||||
|
|
||||||
// Pre-cancel the context so emitStats returns after priming the gauges
|
f.reloadFirewall(cfg)
|
||||||
// without ever reading from ticker.C. The one hour interval is just a
|
|
||||||
// belt-and-suspenders, the test does not expect the ticker to fire.
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel()
|
|
||||||
ifce.emitStats(ctx, time.Hour)
|
|
||||||
|
|
||||||
ttl := ttlGauge.Value()
|
assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced")
|
||||||
assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick")
|
assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks)
|
||||||
assert.LessOrEqual(t, ttl, int64(3600))
|
assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5")))
|
||||||
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value())
|
}
|
||||||
assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value())
|
|
||||||
|
// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when
|
||||||
|
// neither the firewall config nor the cert's UnsafeNetworks have changed.
|
||||||
|
func TestReloadFirewall_NoChange(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
|
||||||
|
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
|
||||||
|
unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
|
||||||
|
|
||||||
|
c1 := &dummyCert{
|
||||||
|
version: cert.Version2,
|
||||||
|
networks: []netip.Prefix{vpnNet},
|
||||||
|
unsafeNetworks: unsafe,
|
||||||
|
}
|
||||||
|
pki := &PKI{}
|
||||||
|
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
|
||||||
|
|
||||||
|
rawYAML := `firewall:
|
||||||
|
outbound:
|
||||||
|
- port: any
|
||||||
|
proto: any
|
||||||
|
host: any
|
||||||
|
inbound:
|
||||||
|
- port: any
|
||||||
|
proto: any
|
||||||
|
host: any
|
||||||
|
`
|
||||||
|
cfg := config.NewC(l)
|
||||||
|
require.NoError(t, cfg.LoadString(rawYAML))
|
||||||
|
|
||||||
|
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
f := &Interface{
|
||||||
|
pki: pki,
|
||||||
|
firewall: fw,
|
||||||
|
l: l,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, cfg.ReloadConfigString(rawYAML))
|
||||||
|
f.reloadFirewall(cfg)
|
||||||
|
|
||||||
|
assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user