From 56657065e03cb1d6e6664266ac13b4e67bd84fd1 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 17 Dec 2019 23:36:12 -0800 Subject: [PATCH 1/5] Fix ca* checks --- firewall.go | 134 ++++++++++++++++++++++++++++++++--------------- firewall_test.go | 81 ++++++++++++++-------------- 2 files changed, 135 insertions(+), 80 deletions(-) diff --git a/firewall.go b/firewall.go index 763a66d..1c5ec9b 100644 --- a/firewall.go +++ b/firewall.go @@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable { } } +type FirewallCA struct { + Any *FirewallRule + CANames map[string]*FirewallRule + CAShas map[string]*FirewallRule +} + type FirewallRule struct { - // Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked - Any bool - Hosts map[string]struct{} - Groups [][]string - CIDR *CIDRTree - CANames map[string]struct{} - CAShas map[string]struct{} + // Any makes Hosts, Groups, and CIDR irrelevant + Any bool + Hosts map[string]struct{} + Groups [][]string + CIDR *CIDRTree } // Even though ports are uint16, int32 maps are faster for lookup // Plus we can use `-1` for fragment rules -type firewallPort map[int32]*FirewallRule +type firewallPort map[int32]*FirewallCA type FirewallPacket struct { LocalIP uint32 @@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { fw := NewFirewall( - c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)), - c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)), - c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)), + c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), + c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), + c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), nc, //TODO: max_connections ) @@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, for i := startPort; i <= endPort; i++ { if _, ok := fp[i]; !ok { - fp[i] = &FirewallRule{ - Groups: make([][]string, 0), - Hosts: make(map[string]struct{}), - CIDR: NewCIDRTree(), - CANames: make(map[string]struct{}), - CAShas: make(map[string]struct{}), + fp[i] = &FirewallCA{ + CANames: make(map[string]*FirewallRule), + CAShas: make(map[string]*FirewallRule), } } @@ -539,15 +540,83 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert return fp[fwPortAny].match(p, c, caPool) } -func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error { - if caName != "" { - fr.CANames[caName] = struct{}{} +func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { + // If there is an any rule then there is no need to establish specific ca rules + if fc.Any != nil { + return fc.Any.addRule(groups, host, ip) + } + + fr := func() *FirewallRule { + return &FirewallRule{ + Hosts: make(map[string]struct{}), + Groups: make([][]string, 0), + CIDR: NewCIDRTree(), + } + } + + any := false + if caSha == "" && caName == "" { + any = true + } + + if any { + if fc.Any == nil { + fc.Any = fr() + } + + // If it's any we need to wipe out any pre-existing rules to save on memory + fc.CAShas = make(map[string]*FirewallRule) + fc.CANames = make(map[string]*FirewallRule) + return fc.Any.addRule(groups, host, ip) } if caSha != "" { - fr.CAShas[caSha] = struct{}{} + if _, ok := fc.CAShas[caSha]; !ok { + fc.CAShas[caSha] = fr() + } + err := fc.CAShas[caSha].addRule(groups, host, ip) + if err != nil { + return err + } } + if caName != "" { + if _, ok := fc.CANames[caName]; !ok { + fc.CANames[caName] = fr() + } + err := fc.CANames[caName].addRule(groups, host, ip) + if err != nil { + return err + } + } + + return nil +} + +func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { + if fc == nil { + return false + } + + if fc.Any != nil { + return fc.Any.match(p, c) + } + + if t, ok := fc.CAShas[c.Details.Issuer]; ok { + if t.match(p, c) { + return true + } + } + + s, err := caPool.GetCAForCert(c) + if err != nil { + return false + } + + return fc.CANames[s.Details.Name].match(p, c) +} + +func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error { if fr.Any { return nil } @@ -593,28 +662,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return false } -func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool { if fr == nil { return false } - // CASha and CAName always need to be checked - if len(fr.CAShas) > 0 { - if _, ok := fr.CAShas[c.Details.Issuer]; !ok { - return false - } - } - - if len(fr.CANames) > 0 { - s, err := caPool.GetCAForCert(c) - if err != nil { - return false - } - if _, ok := fr.CANames[s.Details.Name]; !ok { - return false - } - } - // Shortcut path for if groups, hosts, or cidr contained an `any` if fr.Any { return true @@ -773,7 +825,7 @@ func setTCPRTTTracking(c *conn, p []byte) { ihl := int(p[0]&0x0f) << 2 // Don't track FIN packets - if uint8(p[ihl+13])&tcpFIN != 0 { + if p[ihl+13]&tcpFIN != 0 { return } @@ -787,7 +839,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { } ihl := int(p[0]&0x0f) << 2 - if uint8(p[ihl+13])&tcpACK == 0 { + if p[ihl+13]&tcpACK == 0 { return false } diff --git a/firewall_test.go b/firewall_test.go index 371bb91..b897b44 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "math" "net" "testing" @@ -61,37 +62,37 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) // Make sure an empty rule creates structure but doesn't allow anything to flow //TODO: ideally an empty rule would return an error - assert.False(t, fw.InRules.TCP[1].Any) - assert.Empty(t, fw.InRules.TCP[1].Groups) - assert.Empty(t, fw.InRules.TCP[1].Hosts) - assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left) - assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right) - assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value) + assert.False(t, fw.InRules.TCP[1].Any.Any) + assert.Empty(t, fw.InRules.TCP[1].Any.Groups) + assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) + assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left) + assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right) + assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) - assert.False(t, fw.InRules.UDP[1].Any) - assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1") - assert.Empty(t, fw.InRules.UDP[1].Hosts) - assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left) - assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right) - assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value) + assert.False(t, fw.InRules.UDP[1].Any.Any) + assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") + assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) + assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left) + assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right) + assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) - assert.False(t, fw.InRules.ICMP[1].Any) - assert.Empty(t, fw.InRules.ICMP[1].Groups) - assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1") - assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left) - assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right) - assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value) + assert.False(t, fw.InRules.ICMP[1].Any.Any) + assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) + assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") + assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left) + assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right) + assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) - assert.False(t, fw.OutRules.AnyProto[1].Any) - assert.Empty(t, fw.OutRules.AnyProto[1].Groups) - assert.Empty(t, fw.OutRules.AnyProto[1].Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP))) + assert.False(t, fw.OutRules.AnyProto[1].Any.Any) + assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) + assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP))) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) @@ -104,28 +105,30 @@ func TestFirewall_AddRule(t *testing.T) { // Set any and clear fields fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) - assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0]) - assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1") - assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP))) + assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) + assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") + assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP))) // run twice just to make sure + //TODO: these ANY rules should clear the CA firewall portion assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any) - assert.Empty(t, fw.OutRules.AnyProto[0].Groups) - assert.Empty(t, fw.OutRules.AnyProto[0].Hosts) - assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left) - assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right) - assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any) + assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) + assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value) + fmt.Printf("%+v\n", fw.OutRules.AnyProto[0]) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any) // Test error conditions fw = NewFirewall(time.Second, time.Minute, time.Hour, c) @@ -209,11 +212,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { } _, n, _ := net.ParseCIDR("172.1.1.1/32") - ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") - ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -281,7 +284,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { } }) - ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") + _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") b.Run("pass on ip with any port", func(b *testing.B) { ip := ip2int(net.IPv4(172, 1, 1, 1)) From c359a5cf7128700cec8d67ba0e683ab4be04baf4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 17 Dec 2019 23:43:10 -0800 Subject: [PATCH 2/5] Correct example config doc --- examples/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/config.yml b/examples/config.yml index dd3e9df..593b4ab 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -141,7 +141,7 @@ firewall: # The firewall is default deny. There is no way to write a deny rule. # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR - # Logical evaluation is roughly: port AND proto AND ca_sha AND ca_name AND (host OR group OR groups OR cidr) + # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` From 99cac0da550452e0430852f65b8adbcbb4ae0dac Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 17 Dec 2019 23:48:33 -0800 Subject: [PATCH 3/5] Remove println --- firewall_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/firewall_test.go b/firewall_test.go index b897b44..0e9ede7 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -119,7 +119,6 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left) assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right) assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value) - fmt.Printf("%+v\n", fw.OutRules.AnyProto[0]) fw = NewFirewall(time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) From 4e378fdb5b5c6609dd5c6a25e05b18e84122d9dc Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 18 Dec 2019 11:06:51 -0800 Subject: [PATCH 4/5] Add test for current bug in master, reduce log output in test --- firewall_test.go | 53 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/firewall_test.go b/firewall_test.go index 0e9ede7..3c6025f 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" "math" "net" "testing" @@ -52,6 +51,11 @@ func TestNewFirewall(t *testing.T) { } func TestFirewall_AddRule(t *testing.T) { + ob := &bytes.Buffer{} + out := l.Out + l.SetOutput(ob) + defer l.SetOutput(out) + c := &cert.NebulaCertificate{} fw := NewFirewall(time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) @@ -136,6 +140,11 @@ func TestFirewall_AddRule(t *testing.T) { } func TestFirewall_Drop(t *testing.T) { + ob := &bytes.Buffer{} + out := l.Out + l.SetOutput(ob) + defer l.SetOutput(out) + p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), ip2int(net.IPv4(1, 2, 3, 4)), @@ -152,10 +161,11 @@ func TestFirewall_Drop(t *testing.T) { c := cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Groups: []string{"default-group"}, - Issuer: "signer-shasum", + Name: "host1", + Ips: []*net.IPNet{&ipNet}, + Groups: []string{"default-group"}, + InvertedGroups: map[string]struct{}{"default-group": {}}, + Issuer: "signer-shasum", }, } h := HostInfo{ @@ -182,27 +192,31 @@ func TestFirewall_Drop(t *testing.T) { assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) p.RemoteIP = oldRemote - // test caSha assertions true + // ensure signer doesn't get in the way of group checks fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum")) - assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) - - // test caSha assertions false - fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) - // test caName true - cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + // test caSha doesn't drop on match fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) - // test caName false + // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", "")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) + + // test caName doesn't drop on match + cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) + assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -300,6 +314,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { } func TestFirewall_Drop2(t *testing.T) { + ob := &bytes.Buffer{} + out := l.Out + l.SetOutput(ob) + defer l.SetOutput(out) + p := FirewallPacket{ ip2int(net.IPv4(1, 2, 3, 4)), ip2int(net.IPv4(1, 2, 3, 4)), From 2d8a8143dee9418c34390082dd99c44ad30885c3 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 18 Dec 2019 21:23:34 -0800 Subject: [PATCH 5/5] Actual fix for the real issue with tests --- firewall.go | 21 +++-------- firewall_test.go | 97 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 100 insertions(+), 18 deletions(-) diff --git a/firewall.go b/firewall.go index 1c5ec9b..45373b6 100644 --- a/firewall.go +++ b/firewall.go @@ -541,11 +541,6 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert } func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { - // If there is an any rule then there is no need to establish specific ca rules - if fc.Any != nil { - return fc.Any.addRule(groups, host, ip) - } - fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]struct{}), @@ -554,19 +549,11 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam } } - any := false if caSha == "" && caName == "" { - any = true - } - - if any { if fc.Any == nil { fc.Any = fr() } - // If it's any we need to wipe out any pre-existing rules to save on memory - fc.CAShas = make(map[string]*FirewallRule) - fc.CANames = make(map[string]*FirewallRule) return fc.Any.addRule(groups, host, ip) } @@ -598,8 +585,8 @@ func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool return false } - if fc.Any != nil { - return fc.Any.match(p, c) + if fc.Any.match(p, c) { + return true } if t, ok := fc.CAShas[c.Details.Issuer]; ok { @@ -645,6 +632,10 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err } func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { + if len(groups) == 0 && host == "" && ip == nil { + return true + } + for _, group := range groups { if group == "any" { return true diff --git a/firewall_test.go b/firewall_test.go index 3c6025f..ceb589d 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -64,9 +64,8 @@ func TestFirewall_AddRule(t *testing.T) { _, ti, _ := net.ParseCIDR("1.2.3.4/32") assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) - // Make sure an empty rule creates structure but doesn't allow anything to flow - //TODO: ideally an empty rule would return an error - assert.False(t, fw.InRules.TCP[1].Any.Any) + // An empty rule is any + assert.True(t, fw.InRules.TCP[1].Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left) @@ -182,6 +181,7 @@ func TestFirewall_Drop(t *testing.T) { // Drop outbound assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) // Allow inbound + resetConntrack(fw) assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) // Allow outbound because conntrack assert.False(t, fw.Drop([]byte{}, p, false, &h, cp)) @@ -368,9 +368,94 @@ func TestFirewall_Drop2(t *testing.T) { // h1/c1 lacks the proper groups assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp)) // c has the proper groups + resetConntrack(fw) assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) } +func TestFirewall_Drop3(t *testing.T) { + ob := &bytes.Buffer{} + out := l.Out + l.SetOutput(ob) + defer l.SetOutput(out) + + p := FirewallPacket{ + ip2int(net.IPv4(1, 2, 3, 4)), + ip2int(net.IPv4(1, 2, 3, 4)), + 1, + 1, + fwProtoUDP, + false, + } + + ipNet := net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPMask{255, 255, 255, 0}, + } + + c := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host-owner", + Ips: []*net.IPNet{&ipNet}, + }, + } + + c1 := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host1", + Ips: []*net.IPNet{&ipNet}, + Issuer: "signer-sha-bad", + }, + } + h1 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c1, + }, + } + h1.CreateRemoteCIDR(&c1) + + c2 := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host2", + Ips: []*net.IPNet{&ipNet}, + Issuer: "signer-sha", + }, + } + h2 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c2, + }, + } + h2.CreateRemoteCIDR(&c2) + + c3 := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host3", + Ips: []*net.IPNet{&ipNet}, + Issuer: "signer-sha-bad", + }, + } + h3 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c3, + }, + } + h3.CreateRemoteCIDR(&c3) + + fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", "")) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) + cp := cert.NewCAPool() + + // c1 should pass because host match + assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp)) + // c2 should pass because ca sha match + resetConntrack(fw) + assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp)) + // c3 should fail because no match + resetConntrack(fw) + assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp)) +} + func BenchmarkLookup(b *testing.B) { ml := func(m map[string]struct{}, a [][]string) { for n := 0; n < b.N; n++ { @@ -769,3 +854,9 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end mf.nextCallReturn = nil return err } + +func resetConntrack(fw *Firewall) { + fw.connMutex.Lock() + fw.Conns = map[FirewallPacket]*conn{} + fw.connMutex.Unlock() +}