use in-Nebula SNAT to send IPv4 UnsafeNetworks traffic over an IPv6 overlay

This commit is contained in:
JackDoan
2026-01-14 12:36:55 -06:00
parent 39452b5eec
commit c2a63499ac
22 changed files with 770 additions and 210 deletions

View File

@@ -21,7 +21,7 @@ import (
func TestNewFirewall(t *testing.T) {
l := test.NewLogger()
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
@@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
}
@@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) {
l.SetOutput(ob)
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
@@ -79,56 +79,56 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
//no matter what port is given for icmp, it should end up as "any"
assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err)
@@ -139,7 +139,7 @@ func TestFirewall_AddRule(t *testing.T) {
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
@@ -150,28 +150,28 @@ func TestFirewall_AddRule(t *testing.T) {
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
}
@@ -208,49 +208,49 @@ func TestFirewall_Drop(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
@@ -287,49 +287,49 @@ func TestFirewall_DropV6(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@@ -532,15 +532,15 @@ func TestFirewall_Drop2(t *testing.T) {
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@@ -612,24 +612,24 @@ func TestFirewall_Drop3(t *testing.T) {
}
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
@@ -664,10 +664,10 @@ func TestFirewall_Drop3V6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate)
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -704,35 +704,35 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_ICMPPortBehavior(t *testing.T) {
@@ -771,19 +771,19 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
}
t.Run("ICMP allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
@@ -791,29 +791,29 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
})
t.Run("Any proto, some ports allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
@@ -821,12 +821,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
@@ -834,16 +834,16 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
@@ -851,12 +851,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
@@ -865,15 +865,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
//different ID is blocked
p.RemotePort++
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
})
@@ -908,7 +908,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
cp := cert.NewCAPool()
@@ -922,7 +922,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
@@ -1047,53 +1047,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf := config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
@@ -1336,7 +1336,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
err := fw.Drop(c.p, nil, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
@@ -1344,7 +1344,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
@@ -1364,6 +1364,11 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
return &h
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
h := buildHostinfo(setup, theirPrefixes...)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
@@ -1373,9 +1378,9 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
Fragment: false,
}
return testcase{
h: &h,
h: h,
p: p,
c: &c1,
c: h.ConnectionState.peerCert.Certificate,
err: err,
}
}
@@ -1397,12 +1402,25 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
return newSetupFromCert(t, l, c)
}
func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup {
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
out := newSetupFromCert(t, l, c)
out.fw.snatAddr = snatAddr
return out
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
@@ -1532,3 +1550,59 @@ func resetConntrack(fw *Firewall) {
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
fw.Conntrack.Unlock()
}
func TestFirewall_SNAT(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
cp := cert.NewCAPool()
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
MyCert := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
theirPrefix := netip.MustParsePrefix("1.2.2.2/8")
snatAddr := netip.MustParseAddr("169.254.55.96")
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
myCert := MyCert.Copy()
setup := newSnatSetup(t, l, myPrefix, snatAddr)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
resetConntrack(setup.fw)
h := buildHostinfo(setup, theirPrefix)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: h.vpnAddrs[0],
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil))
})
//t.Run("allow inbound unsafe route", func(t *testing.T) {
// t.Parallel()
// unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
// c := dummyCert{
// name: "me",
// networks: []netip.Prefix{myPrefix},
// unsafeNetworks: []netip.Prefix{unsafePrefix},
// groups: []string{"default-group"},
// issuer: "signer-shasum",
// }
// unsafeSetup := newSetupFromCert(t, l, c)
// tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
// tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
// tc.err = ErrNoMatchingRule
// tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
// require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
// tc.err = nil
// tc.Test(t, unsafeSetup.fw) //should pass
//})
}