diff --git a/firewall_test.go b/firewall_test.go index 4731a6f..49f4d73 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -68,6 +68,9 @@ func TestFirewall_AddRule(t *testing.T) { ti, err := netip.ParsePrefix("1.2.3.4/32") require.NoError(t, err) + ti6, err := netip.ParsePrefix("fd12::34/128") + require.NoError(t, err) + require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) @@ -92,12 +95,24 @@ func TestFirewall_AddRule(t *testing.T) { _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", "")) + assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) + _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) + assert.True(t, ok) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", "")) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) + _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) + assert.True(t, ok) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") @@ -117,6 +132,13 @@ func TestFirewall_AddRule(t *testing.T) { require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + anyIp6, err := netip.ParsePrefix("::/0") + require.NoError(t, err) + + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", "")) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) + // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -199,6 +221,82 @@ func TestFirewall_Drop(t *testing.T) { require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } +func TestFirewall_DropV6(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + + p := firewall.Packet{ + LocalAddr: netip.MustParseAddr("fd12::34"), + RemoteAddr: netip.MustParseAddr("fd12::34"), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + + c := dummyCert{ + name: "host1", + networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: &c, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, + } + h.buildNetworks(c.networks, c.unsafeNetworks) + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + cp := cert.NewCAPool() + + // Drop outbound + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + // Allow outbound because conntrack + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + + // test remote mismatch + oldRemote := p.RemoteAddr + p.RemoteAddr = netip.MustParseAddr("fd12::56") + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + p.RemoteAddr = oldRemote + + // ensure signer doesn't get in the way of group checks + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + + // test caSha doesn't drop on match + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + + // ensure ca name doesn't get in the way of group checks + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + + // test caName doesn't drop on match + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) +} + func BenchmarkFirewallTable_match(b *testing.B) { f := &Firewall{} ft := FirewallTable{ @@ -208,6 +306,10 @@ func BenchmarkFirewallTable_match(b *testing.B) { pfix := netip.MustParsePrefix("172.1.1.1/32") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") + + pfix6 := netip.MustParsePrefix("fd11::11/128") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -239,6 +341,15 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) + b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } + ip := netip.MustParsePrefix("fd99::99/128") + for n := 0; n < b.N; n++ { + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) + } + }) b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -252,6 +363,18 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) } }) + b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")}, + }, + InvertedGroups: map[string]struct{}{"nope": {}}, + } + for n := 0; n < b.N; n++ { + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) + } + }) b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -265,6 +388,18 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) + b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")}, + }, + InvertedGroups: map[string]struct{}{"nope": {}}, + } + for n := 0; n < b.N; n++ { + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp)) + } + }) b.Run("pass on group on any local cidr", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -289,6 +424,17 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) + b.Run("pass on group on specific local cidr6", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + }, + InvertedGroups: map[string]struct{}{"good-group": {}}, + } + for n := 0; n < b.N; n++ { + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp)) + } + }) b.Run("pass on name", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -447,6 +593,42 @@ func TestFirewall_Drop3(t *testing.T) { require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } +func TestFirewall_Drop3V6(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + + p := firewall.Packet{ + LocalAddr: netip.MustParseAddr("fd12::34"), + RemoteAddr: netip.MustParseAddr("fd12::34"), + LocalPort: 1, + RemotePort: 1, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + + network := netip.MustParsePrefix("fd12::34/120") + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host-owner", + networks: []netip.Prefix{network}, + }, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + + // Test a remote address match + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + cp := cert.NewCAPool() + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) +} + func TestFirewall_DropConntrackReload(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} @@ -727,6 +909,21 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) + // Test adding rule with cidr ipv6 + cidr6 := netip.MustParsePrefix("fd00::/8") + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall) + + // Test adding rule with local_cidr ipv6 + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall) + // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{}