mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Preserve conntrack table during firewall rules reload (SIGHUP) (#233)
Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe). This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded.
This commit is contained in:
106
firewall_test.go
106
firewall_test.go
@@ -17,37 +17,39 @@ import (
|
||||
func TestNewFirewall(t *testing.T) {
|
||||
c := &cert.NebulaCertificate{}
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.NotNil(t, fw.Conns)
|
||||
conntrack := fw.Conntrack
|
||||
assert.NotNil(t, conntrack)
|
||||
assert.NotNil(t, conntrack.Conns)
|
||||
assert.NotNil(t, conntrack.TimerWheel)
|
||||
assert.NotNil(t, fw.InRules)
|
||||
assert.NotNil(t, fw.OutRules)
|
||||
assert.NotNil(t, fw.TimerWheel)
|
||||
assert.Equal(t, time.Second, fw.TCPTimeout)
|
||||
assert.Equal(t, time.Minute, fw.UDPTimeout)
|
||||
assert.Equal(t, time.Hour, fw.DefaultTimeout)
|
||||
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
}
|
||||
|
||||
func TestFirewall_AddRule(t *testing.T) {
|
||||
@@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
p := FirewallPacket{
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
10,
|
||||
90,
|
||||
fwProtoUDP,
|
||||
false,
|
||||
}
|
||||
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Groups: []string{"default-group"},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
Issuer: "signer-shasum",
|
||||
},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
// Allow outbound because conntrack
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Allow outbound because conntrack and new rules allow port 10
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Drop outbound because conntrack doesn't match new ruleset
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
ml := func(m map[string]struct{}, a [][]string) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
@@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
||||
}
|
||||
|
||||
func resetConntrack(fw *Firewall) {
|
||||
fw.connMutex.Lock()
|
||||
fw.Conns = map[FirewallPacket]*conn{}
|
||||
fw.connMutex.Unlock()
|
||||
fw.Conntrack.Lock()
|
||||
fw.Conntrack.Conns = map[FirewallPacket]*conn{}
|
||||
fw.Conntrack.Unlock()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user