In the middle

This commit is contained in:
Nate Brown 2024-01-23 16:02:10 -06:00
parent 8822f1366c
commit 8f44f22c37
4 changed files with 253 additions and 180 deletions

View File

@ -142,15 +142,22 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
return ok, value return ok, value
} }
// Match finds the most specific match type eachFunc[T any] func(T) bool
// TODO this is exact match
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) { // EachContains will call a function, passing the value, for each entry until the function returns false or the search is complete
// The final return value will be true if the provided function returned true
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
bit := startbit bit := startbit
node := tree.root node := tree.root
lastNode := node
for node != nil { for node != nil {
lastNode = node if node.hasValue {
// If the each func returns true then we can exit the loop
if each(node.value) {
return true
}
}
if ip&bit != 0 { if ip&bit != 0 {
node = node.right node = node.right
} else { } else {
@ -160,10 +167,33 @@ func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
bit >>= 1 bit >>= 1
} }
if bit == 0 && lastNode != nil { return false
value = lastNode.value
ok = true
} }
// GetCIDR returns the entry added by the most recent matching AddCIDR call
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
bit := startbit
node := tree.root
ip := iputil.Ip2VpnIp(cidr.IP)
mask := iputil.Ip2VpnIp(cidr.Mask)
// Find our last ancestor in the tree
for node != nil && bit&mask != 0 {
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit = bit >> 1
}
if bit&mask == 0 && node != nil {
value = node.value
ok = node.hasValue
}
return ok, value return ok, value
} }

View File

@ -115,35 +115,36 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
assert.Equal(t, "cool", r) assert.Equal(t, "cool", r)
} }
func TestCIDRTree_Match(t *testing.T) { func TestTree4_GetCIDR(t *testing.T) {
tree := NewTree4[string]() tree := NewTree4[string]()
tree.AddCIDR(Parse("4.1.1.0/32"), "1a") tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("4.1.1.1/32"), "1b") tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct { tests := []struct {
Found bool Found bool
Result interface{} Result interface{}
IP string IPNet *net.IPNet
}{ }{
{true, "1a", "4.1.1.0"}, {true, "1", Parse("1.0.0.0/8")},
{true, "1b", "4.1.1.1"}, {true, "2", Parse("2.1.0.0/16")},
{true, "3", Parse("3.1.1.0/24")},
{true, "4a", Parse("4.1.1.0/24")},
{true, "4b", Parse("4.1.1.1/32")},
{true, "4c", Parse("4.1.2.1/32")},
{true, "5", Parse("254.0.0.0/4")},
{false, "", Parse("2.0.0.0/8")},
} }
for _, tt := range tests { for _, tt := range tests {
ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) ok, r := tree.GetCIDR(tt.IPNet)
assert.Equal(t, tt.Found, ok) assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r) assert.Equal(t, tt.Result, r)
} }
tree = NewTree4[string]()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
} }
func BenchmarkCIDRTree_Contains(b *testing.B) { func BenchmarkCIDRTree_Contains(b *testing.B) {
@ -167,25 +168,3 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
} }
}) })
} }
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}

View File

@ -84,6 +84,8 @@ type FirewallConntrack struct {
TimerWheel *TimerWheel[firewall.Packet] TimerWheel *TimerWheel[firewall.Packet]
} }
// FirewallTable is the entry point for a rule, the evaluation order is:
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
type FirewallTable struct { type FirewallTable struct {
TCP firewallPort TCP firewallPort
UDP firewallPort UDP firewallPort
@ -101,24 +103,28 @@ func newFirewallTable() *FirewallTable {
} }
type FirewallCA struct { type FirewallCA struct {
Any *FirewallRule Any *firewallLocalCIDR
CANames map[string]*FirewallRule CANames map[string]*firewallLocalCIDR
CAShas map[string]*FirewallRule CAShas map[string]*firewallLocalCIDR
} }
type FirewallRule struct { type FirewallRule struct {
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant // Any makes Hosts, Groups, and CIDR irrelevant
Any bool Any bool
Hosts map[string]struct{} Hosts map[string]struct{}
Groups [][]string Groups [][]string
CIDR *cidr.Tree4[struct{}] CIDR *cidr.Tree4[struct{}]
LocalCIDR *cidr.Tree4[struct{}]
} }
// Even though ports are uint16, int32 maps are faster for lookup // Even though ports are uint16, int32 maps are faster for lookup
// Plus we can use `-1` for fragment rules // Plus we can use `-1` for fragment rules
type firewallPort map[int32]*FirewallCA type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any *FirewallRule
LocalCIDR *cidr.Tree4[*FirewallRule]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration //TODO: error on 0 duration
@ -632,8 +638,8 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
for i := startPort; i <= endPort; i++ { for i := startPort; i <= endPort; i++ {
if _, ok := fp[i]; !ok { if _, ok := fp[i]; !ok {
fp[i] = &FirewallCA{ fp[i] = &FirewallCA{
CANames: make(map[string]*FirewallRule), CANames: make(map[string]*firewallLocalCIDR),
CAShas: make(map[string]*FirewallRule), CAShas: make(map[string]*firewallLocalCIDR),
} }
} }
@ -669,18 +675,15 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
} }
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fr := func() *FirewallRule { fl := func() *firewallLocalCIDR {
return &FirewallRule{ return &firewallLocalCIDR{
Hosts: make(map[string]struct{}), LocalCIDR: cidr.NewTree4[*FirewallRule](),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4[struct{}](),
LocalCIDR: cidr.NewTree4[struct{}](),
} }
} }
if caSha == "" && caName == "" { if caSha == "" && caName == "" {
if fc.Any == nil { if fc.Any == nil {
fc.Any = fr() fc.Any = fl()
} }
return fc.Any.addRule(groups, host, ip, localIp) return fc.Any.addRule(groups, host, ip, localIp)
@ -688,7 +691,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
if caSha != "" { if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok { if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr() fc.CAShas[caSha] = fl()
} }
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp) err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
if err != nil { if err != nil {
@ -698,7 +701,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
if caName != "" { if caName != "" {
if _, ok := fc.CANames[caName]; !ok { if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr() fc.CANames[caName] = fl()
} }
err := fc.CANames[caName].addRule(groups, host, ip, localIp) err := fc.CANames[caName].addRule(groups, host, ip, localIp)
if err != nil { if err != nil {
@ -732,18 +735,63 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c) return fc.CANames[s.Details.Name].match(p, c)
} }
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error { func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4[struct{}](),
}
}
if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
if fc.Any == nil {
fc.Any = fr()
}
return fc.Any.addRule(groups, host, ip)
}
_, efr := fc.LocalCIDR.GetCIDR(localIp)
if efr != nil {
return efr.addRule(groups, host, ip)
}
nfr := fr()
err := nfr.addRule(groups, host, ip)
if err != nil {
return err
}
fc.LocalCIDR.AddCIDR(localIp, nfr)
return nil
}
func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if fc == nil {
return false
}
if fc.Any.match(p, c) {
return true
}
return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool {
return fr.match(p, c)
})
}
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
if fr.Any { if fr.Any {
return nil return nil
} }
if fr.isAny(groups, host, ip, localIp) { if fr.isAny(groups, host, ip) {
fr.Any = true fr.Any = true
// If it's any we need to wipe out any pre-existing rules to save on memory // If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0) fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{}) fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4[struct{}]() fr.CIDR = cidr.NewTree4[struct{}]()
fr.LocalCIDR = cidr.NewTree4[struct{}]()
} else { } else {
if len(groups) > 0 { if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups) fr.Groups = append(fr.Groups, groups)
@ -756,17 +804,13 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
if ip != nil { if ip != nil {
fr.CIDR.AddCIDR(ip, struct{}{}) fr.CIDR.AddCIDR(ip, struct{}{})
} }
if localIp != nil {
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
}
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil && localIp == nil { if len(groups) == 0 && host == "" && ip == nil {
return true return true
} }
@ -784,10 +828,6 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPN
return true return true
} }
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
return true
}
return false return false
} }
@ -832,13 +872,6 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
} }
} }
if fr.LocalCIDR != nil {
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
if ok {
return true
}
}
// No host, group, or cidr matched, bye bye // No host, group, or cidr matched, bye bye
return false return false
} }

View File

@ -71,37 +71,34 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any) assert.False(t, fw.InRules.UDP[1].Any.Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Any.Groups[0], "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.False(t, fw.InRules.ICMP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.False(t, fw.OutRules.AnyProto[1].Any.Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) ok, _ := fw.OutRules.AnyProto[1].Any.Any.CIDR.GetCIDR(ti)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) ok, fr := fw.OutRules.AnyProto[1].Any.LocalCIDR.GetCIDR(ti)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok) assert.True(t, ok)
assert.True(t, fr.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
@ -114,29 +111,28 @@ func TestFirewall_AddRule(t *testing.T) {
// Set any and clear fields // Set any and clear fields
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) ok, fr = fw.OutRules.AnyProto[0].Any.LocalCIDR.GetCIDR(ti)
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok)
ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok) assert.True(t, ok)
assert.False(t, fr.Any)
assert.Equal(t, []string{"g1", "g2"}, fr.Groups[0])
assert.Contains(t, fr.Hosts, "h1")
// run twice just to make sure // run twice just to make sure
//TODO: these ANY rules should clear the CA firewall portion //TODO: these ANY rules should clear the CA firewall portion
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0") _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@ -231,28 +227,37 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
_, n, _ := net.ParseCIDR("172.1.1.1/32") _, n, _ := net.ParseCIDR("172.1.1.1/32")
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "") goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "") _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "") _ = ft.TCP.addRule(100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
// This benchmark is showing us the cost of failing to match the protocol
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
} }
}) })
b.Run("fail on port", func(b *testing.B) { b.Run("pass proto, fail on port", func(b *testing.B) {
// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
} }
}) })
b.Run("fail all group, name, and cidr", func(b *testing.B) { b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
c := &cert.NebulaCertificate{}
ip, _, _ := net.ParseCIDR("9.254.254.254/32")
lip := iputil.Ip2VpnIp(ip)
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32") _, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{ c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ Details: cert.NebulaCertificateDetails{
@ -262,11 +267,25 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
} }
}) })
b.Run("pass on group", func(b *testing.B) { b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "nope",
Ips: []*net.IPNet{ip},
},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{ c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}}, InvertedGroups: map[string]struct{}{"good-group": {}},
@ -274,7 +293,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass on group on specific local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}},
Name: "nope",
},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
} }
}) })
@ -289,60 +320,60 @@ func BenchmarkFirewallTable_match(b *testing.B) {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
} }
}) })
//
b.Run("pass on ip", func(b *testing.B) { //b.Run("pass on ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
} // }
}) //})
//
b.Run("pass on local ip", func(b *testing.B) { //b.Run("pass on local ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
} // }
}) //})
//
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
//
b.Run("pass on ip with any port", func(b *testing.B) { //b.Run("pass on ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
} // }
}) //})
//
b.Run("pass on local ip with any port", func(b *testing.B) { //b.Run("pass on local ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
} // }
}) //})
} }
func TestFirewall_Drop2(t *testing.T) { func TestFirewall_Drop2(t *testing.T) {