mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
Make firewall reload when unsafe networks in the cert changes
This commit is contained in:
30
firewall.go
30
firewall.go
@@ -58,8 +58,9 @@ type Firewall struct {
|
||||
routableNetworks *bart.Lite
|
||||
|
||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||
assignedNetworks []netip.Prefix
|
||||
hasUnsafeNetworks bool
|
||||
assignedNetworks []netip.Prefix
|
||||
// unsafeNetworks is the list of unsafe networks issued to us in the certificate
|
||||
unsafeNetworks []netip.Prefix
|
||||
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
@@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
for _, n := range c.UnsafeNetworks() {
|
||||
unsafeNetworks := c.UnsafeNetworks()
|
||||
for _, n := range unsafeNetworks {
|
||||
routableNetworks.Insert(n)
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
@@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||
l: l,
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
l: l,
|
||||
|
||||
incomingMetrics: firewallMetrics{
|
||||
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
|
||||
@@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
||||
}
|
||||
|
||||
if localCidr == "" {
|
||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
||||
if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user