mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Make firewall reload when unsafe networks in the cert changes
This commit is contained in:
12
firewall.go
12
firewall.go
@@ -59,7 +59,8 @@ type Firewall struct {
|
||||
|
||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||
assignedNetworks []netip.Prefix
|
||||
hasUnsafeNetworks bool
|
||||
// unsafeNetworks is the list of unsafe networks issued to us in the certificate
|
||||
unsafeNetworks []netip.Prefix
|
||||
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
for _, n := range c.UnsafeNetworks() {
|
||||
unsafeNetworks := c.UnsafeNetworks()
|
||||
for _, n := range unsafeNetworks {
|
||||
routableNetworks.Insert(n)
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
@@ -176,7 +176,7 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
||||
DefaultTimeout: defaultTimeout,
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
l: l,
|
||||
|
||||
incomingMetrics: firewallMetrics{
|
||||
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
||||
}
|
||||
|
||||
if localCidr == "" {
|
||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
||||
if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
|
||||
17
interface.go
17
interface.go
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
|
||||
}
|
||||
|
||||
func (f *Interface) reloadFirewall(c *config.C) {
|
||||
//TODO: need to trigger/detect if the certificate changed too
|
||||
if c.HasChanged("firewall") == false {
|
||||
cs := f.pki.getCertState()
|
||||
curCert := cs.getCertificate(cert.Version2)
|
||||
if curCert == nil {
|
||||
curCert = cs.getCertificate(cert.Version1)
|
||||
}
|
||||
|
||||
// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
|
||||
// Check to see if that set has changed, and if so, rebuild the firewall.
|
||||
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)
|
||||
|
||||
if !c.HasChanged("firewall") && !certUnsafeChanged {
|
||||
f.l.Debug("No firewall config change detected")
|
||||
return
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||
fw, err := NewFirewallFromConfig(f.l, cs, c)
|
||||
if err != nil {
|
||||
f.l.Error("Error while creating firewall during reload", "error", err)
|
||||
return
|
||||
|
||||
120
interface_test.go
Normal file
120
interface_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall
|
||||
// rebuilds the firewall when only the certificate's UnsafeNetworks have changed,
|
||||
// even if the firewall section of the YAML has not.
|
||||
func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
|
||||
initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
|
||||
|
||||
// dummyCert avoids dragging the real signing pipeline into a unit test.
|
||||
c1 := &dummyCert{
|
||||
version: cert.Version2,
|
||||
networks: []netip.Prefix{vpnNet},
|
||||
unsafeNetworks: initialUnsafe,
|
||||
}
|
||||
pki := &PKI{}
|
||||
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
|
||||
|
||||
rawYAML := `firewall:
|
||||
outbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
inbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
`
|
||||
cfg := config.NewC(l)
|
||||
require.NoError(t, cfg.LoadString(rawYAML))
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialUnsafe, fw.unsafeNetworks)
|
||||
|
||||
f := &Interface{
|
||||
pki: pki,
|
||||
firewall: fw,
|
||||
l: l,
|
||||
}
|
||||
|
||||
// Swap the cert with a different UnsafeNetworks set.
|
||||
newUnsafe := []netip.Prefix{
|
||||
netip.MustParsePrefix("198.51.100.0/24"),
|
||||
netip.MustParsePrefix("203.0.113.0/24"),
|
||||
}
|
||||
c2 := &dummyCert{
|
||||
version: cert.Version2,
|
||||
networks: []netip.Prefix{vpnNet},
|
||||
unsafeNetworks: newUnsafe,
|
||||
}
|
||||
pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2})
|
||||
|
||||
// Reload with the same YAML so HasChanged("firewall") reports false.
|
||||
require.NoError(t, cfg.ReloadConfigString(rawYAML))
|
||||
require.False(t, cfg.HasChanged("firewall"))
|
||||
|
||||
f.reloadFirewall(cfg)
|
||||
|
||||
assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced")
|
||||
assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks)
|
||||
assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5")))
|
||||
}
|
||||
|
||||
// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when
|
||||
// neither the firewall config nor the cert's UnsafeNetworks have changed.
|
||||
func TestReloadFirewall_NoChange(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
vpnNet := netip.MustParsePrefix("10.0.0.1/24")
|
||||
unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}
|
||||
|
||||
c1 := &dummyCert{
|
||||
version: cert.Version2,
|
||||
networks: []netip.Prefix{vpnNet},
|
||||
unsafeNetworks: unsafe,
|
||||
}
|
||||
pki := &PKI{}
|
||||
pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2})
|
||||
|
||||
rawYAML := `firewall:
|
||||
outbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
inbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
`
|
||||
cfg := config.NewC(l)
|
||||
require.NoError(t, cfg.LoadString(rawYAML))
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &Interface{
|
||||
pki: pki,
|
||||
firewall: fw,
|
||||
l: l,
|
||||
}
|
||||
|
||||
require.NoError(t, cfg.ReloadConfigString(rawYAML))
|
||||
f.reloadFirewall(cfg)
|
||||
|
||||
assert.Same(t, fw, f.firewall, "firewall should not have been replaced")
|
||||
}
|
||||
Reference in New Issue
Block a user