From 629700fbb6a5279415869f3fc147c5b27cde14bc Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 26 Feb 2026 10:58:10 -0600 Subject: [PATCH] feedback --- firewall.go | 3 +- firewall_test.go | 92 ++++++++++++++++++++++++------------------------ snat_test.go | 44 +++++++++++------------ 3 files changed, 69 insertions(+), 70 deletions(-) diff --git a/firewall.go b/firewall.go index 9eb2a670..4d8d7b3b 100644 --- a/firewall.go +++ b/firewall.go @@ -165,7 +165,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall { +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -241,7 +241,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, - netip.Addr{}, ) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) diff --git a/firewall_test.go b/firewall_test.go index c42cad65..4df6eadd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -21,7 +21,7 @@ import ( func TestNewFirewall(t *testing.T) { l := test.NewLogger() c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack assert.NotNil(t, conntrack) assert.NotNil(t, conntrack.Conns) @@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) { assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{}) + fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{}) + fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{}) + fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) } @@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) { l.SetOutput(ob) c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) @@ -79,56 +79,56 @@ func TestFirewall_AddRule(t *testing.T) { assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) //no matter what port is given for icmp, it should end up as "any" assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") require.NoError(t, err) @@ -139,7 +139,7 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp6, err := netip.ParsePrefix("::/0") require.NoError(t, err) @@ -150,28 +150,28 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) } @@ -208,7 +208,7 @@ func TestFirewall_Drop(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -227,27 +227,27 @@ func TestFirewall_Drop(t *testing.T) { p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -287,7 +287,7 @@ func TestFirewall_DropV6(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -306,27 +306,27 @@ func TestFirewall_DropV6(t *testing.T) { p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -532,7 +532,7 @@ func TestFirewall_Drop2(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -612,7 +612,7 @@ func TestFirewall_Drop3(t *testing.T) { } h3.buildNetworks(myVpnNetworksTable, c3.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) cp := cert.NewCAPool() @@ -627,7 +627,7 @@ func TestFirewall_Drop3(t *testing.T) { assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) } @@ -664,7 +664,7 @@ func TestFirewall_Drop3V6(t *testing.T) { h.buildNetworks(myVpnNetworksTable, c.Certificate) // Test a remote address match - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -704,7 +704,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, c.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -717,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -726,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -771,7 +771,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { } t.Run("ICMP allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports", func(t *testing.T) { p := templ.Copy() @@ -801,7 +801,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { }) t.Run("Any proto, some ports allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, still blocked", func(t *testing.T) { p := templ.Copy() @@ -843,7 +843,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { }) }) t.Run("Any proto, any port", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, allowed", func(t *testing.T) { resetConntrack(fw) @@ -908,7 +908,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -1420,7 +1420,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) return testsetup{ @@ -1572,7 +1572,7 @@ func TestFirewall_SNAT(t *testing.T) { t.Parallel() myCert := MyCert.Copy() setup := newSnatSetup(t, l, myPrefix, snatAddr) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) resetConntrack(setup.fw) h := buildHostinfo(setup, theirPrefix) diff --git a/snat_test.go b/snat_test.go index 8fb43a2b..14d55f4a 100644 --- a/snat_test.go +++ b/snat_test.go @@ -413,7 +413,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fp := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), @@ -434,7 +434,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fp := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), @@ -459,7 +459,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) // Fill all ports baseFP := firewall.Packet{ @@ -498,7 +498,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -531,7 +531,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -564,7 +564,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -593,7 +593,7 @@ func TestFirewall_ApplySnat(t *testing.T) { c := &dummyCert{ networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) pkt := slices.Clone(canonicalUDPTest) fp := firewall.Packet{ @@ -615,7 +615,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -648,7 +648,7 @@ func TestFirewall_UnSnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr // Create a conntrack entry for the snatted flow @@ -693,7 +693,7 @@ func TestFirewall_UnSnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPReply) @@ -727,7 +727,7 @@ func TestFirewall_Drop_SNATFullFlow(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) @@ -816,7 +816,7 @@ func TestFirewall_ApplySnat_CrossHostHijack(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr // Simulate Host A having established a flow @@ -860,7 +860,7 @@ func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { } t.Run("v6 first then v4", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -887,7 +887,7 @@ func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { }) t.Run("v4 first then v6", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -923,7 +923,7 @@ func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { c := &dummyCert{ networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) pkt := slices.Clone(canonicalUDPTest) @@ -948,7 +948,7 @@ func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -980,7 +980,7 @@ func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -1013,7 +1013,7 @@ func TestFirewall_UnSnat_NonSNATConntrack(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr // Create a conntrack entry with snat=nil (a normal non-SNAT connection) @@ -1061,7 +1061,7 @@ func TestFirewall_Drop_FirewallBlocksSNAT(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr // Only allow port 80 inbound require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 80, []string{"any"}, "", "", "any", "", "")) @@ -1121,7 +1121,7 @@ func TestFirewall_Drop_SNATLocalAddrNotRoutable(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) @@ -1176,7 +1176,7 @@ func TestFirewall_Drop_NoSnatAddrRejectsV6Peer(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) peerV6Addr := netip.MustParseAddr("fd00::2") @@ -1236,7 +1236,7 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", ""))