checkpt, try to parse packets only once pt2

This commit is contained in:
JackDoan
2026-05-07 11:26:17 -05:00
parent 0375aff451
commit 5bdf645b0b
10 changed files with 1150 additions and 92 deletions

View File

@@ -211,44 +211,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p.Key(), &p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
@@ -289,44 +289,44 @@ func TestFirewall_DropV6(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p.Key(), &p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@@ -533,10 +533,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
require.ErrorIs(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@@ -613,18 +613,18 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool()
// c1 should pass because host match
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
@@ -661,7 +661,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -702,12 +702,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@@ -716,7 +716,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@@ -725,7 +725,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_ICMPPortBehavior(t *testing.T) {
@@ -770,12 +770,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
@@ -783,12 +783,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
})
})
@@ -800,12 +800,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
@@ -813,12 +813,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
@@ -826,12 +826,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
@@ -843,12 +843,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
@@ -857,15 +857,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
//different ID is blocked
p.RemotePort++
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
require.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
@@ -913,7 +913,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
@@ -1327,7 +1327,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
err := fw.Drop(c.p.Key(), &c.p, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
@@ -1519,6 +1519,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
func resetConntrack(fw *Firewall) {
fw.Conntrack.Lock()
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
fw.Conntrack.Conns = map[firewall.PacketKey]*conn{}
fw.Conntrack.Unlock()
}