mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-10 21:43:57 +01:00
In the middle
This commit is contained in:
parent
8822f1366c
commit
8f44f22c37
@ -142,15 +142,22 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
|
|||||||
return ok, value
|
return ok, value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Match finds the most specific match
|
type eachFunc[T any] func(T) bool
|
||||||
// TODO this is exact match
|
|
||||||
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
|
// EachContains will call a function, passing the value, for each entry until the function returns false or the search is complete
|
||||||
|
// The final return value will be true if the provided function returned true
|
||||||
|
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
lastNode := node
|
|
||||||
|
|
||||||
for node != nil {
|
for node != nil {
|
||||||
lastNode = node
|
if node.hasValue {
|
||||||
|
// If the each func returns true then we can exit the loop
|
||||||
|
if each(node.value) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
node = node.right
|
node = node.right
|
||||||
} else {
|
} else {
|
||||||
@ -160,10 +167,33 @@ func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
|
|||||||
bit >>= 1
|
bit >>= 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if bit == 0 && lastNode != nil {
|
return false
|
||||||
value = lastNode.value
|
|
||||||
ok = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCIDR returns the entry added by the most recent matching AddCIDR call
|
||||||
|
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
|
||||||
|
bit := startbit
|
||||||
|
node := tree.root
|
||||||
|
|
||||||
|
ip := iputil.Ip2VpnIp(cidr.IP)
|
||||||
|
mask := iputil.Ip2VpnIp(cidr.Mask)
|
||||||
|
|
||||||
|
// Find our last ancestor in the tree
|
||||||
|
for node != nil && bit&mask != 0 {
|
||||||
|
if ip&bit != 0 {
|
||||||
|
node = node.right
|
||||||
|
} else {
|
||||||
|
node = node.left
|
||||||
|
}
|
||||||
|
|
||||||
|
bit = bit >> 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if bit&mask == 0 && node != nil {
|
||||||
|
value = node.value
|
||||||
|
ok = node.hasValue
|
||||||
|
}
|
||||||
|
|
||||||
return ok, value
|
return ok, value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -115,35 +115,36 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
|||||||
assert.Equal(t, "cool", r)
|
assert.Equal(t, "cool", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDRTree_Match(t *testing.T) {
|
func TestTree4_GetCIDR(t *testing.T) {
|
||||||
tree := NewTree4[string]()
|
tree := NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
|
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||||
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
|
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||||
|
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
|
||||||
|
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
|
||||||
|
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
|
||||||
|
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
Found bool
|
Found bool
|
||||||
Result interface{}
|
Result interface{}
|
||||||
IP string
|
IPNet *net.IPNet
|
||||||
}{
|
}{
|
||||||
{true, "1a", "4.1.1.0"},
|
{true, "1", Parse("1.0.0.0/8")},
|
||||||
{true, "1b", "4.1.1.1"},
|
{true, "2", Parse("2.1.0.0/16")},
|
||||||
|
{true, "3", Parse("3.1.1.0/24")},
|
||||||
|
{true, "4a", Parse("4.1.1.0/24")},
|
||||||
|
{true, "4b", Parse("4.1.1.1/32")},
|
||||||
|
{true, "4c", Parse("4.1.2.1/32")},
|
||||||
|
{true, "5", Parse("254.0.0.0/4")},
|
||||||
|
{false, "", Parse("2.0.0.0/8")},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
|
ok, r := tree.GetCIDR(tt.IPNet)
|
||||||
assert.Equal(t, tt.Found, ok)
|
assert.Equal(t, tt.Found, ok)
|
||||||
assert.Equal(t, tt.Result, r)
|
assert.Equal(t, tt.Result, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
tree = NewTree4[string]()
|
|
||||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
|
||||||
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, "cool", r)
|
|
||||||
|
|
||||||
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, "cool", r)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
||||||
@ -167,25 +168,3 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkCIDRTree_Match(b *testing.B) {
|
|
||||||
tree := NewTree4[string]()
|
|
||||||
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
|
||||||
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
|
||||||
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
|
||||||
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
|
|
||||||
|
|
||||||
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
|
|
||||||
b.Run("found", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tree.Match(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
|
|
||||||
b.Run("not found", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tree.Match(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
105
firewall.go
105
firewall.go
@ -84,6 +84,8 @@ type FirewallConntrack struct {
|
|||||||
TimerWheel *TimerWheel[firewall.Packet]
|
TimerWheel *TimerWheel[firewall.Packet]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FirewallTable is the entry point for a rule, the evaluation order is:
|
||||||
|
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
|
||||||
type FirewallTable struct {
|
type FirewallTable struct {
|
||||||
TCP firewallPort
|
TCP firewallPort
|
||||||
UDP firewallPort
|
UDP firewallPort
|
||||||
@ -101,24 +103,28 @@ func newFirewallTable() *FirewallTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type FirewallCA struct {
|
type FirewallCA struct {
|
||||||
Any *FirewallRule
|
Any *firewallLocalCIDR
|
||||||
CANames map[string]*FirewallRule
|
CANames map[string]*firewallLocalCIDR
|
||||||
CAShas map[string]*FirewallRule
|
CAShas map[string]*firewallLocalCIDR
|
||||||
}
|
}
|
||||||
|
|
||||||
type FirewallRule struct {
|
type FirewallRule struct {
|
||||||
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
|
// Any makes Hosts, Groups, and CIDR irrelevant
|
||||||
Any bool
|
Any bool
|
||||||
Hosts map[string]struct{}
|
Hosts map[string]struct{}
|
||||||
Groups [][]string
|
Groups [][]string
|
||||||
CIDR *cidr.Tree4[struct{}]
|
CIDR *cidr.Tree4[struct{}]
|
||||||
LocalCIDR *cidr.Tree4[struct{}]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Even though ports are uint16, int32 maps are faster for lookup
|
// Even though ports are uint16, int32 maps are faster for lookup
|
||||||
// Plus we can use `-1` for fragment rules
|
// Plus we can use `-1` for fragment rules
|
||||||
type firewallPort map[int32]*FirewallCA
|
type firewallPort map[int32]*FirewallCA
|
||||||
|
|
||||||
|
type firewallLocalCIDR struct {
|
||||||
|
Any *FirewallRule
|
||||||
|
LocalCIDR *cidr.Tree4[*FirewallRule]
|
||||||
|
}
|
||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||||
//TODO: error on 0 duration
|
//TODO: error on 0 duration
|
||||||
@ -632,8 +638,8 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
|||||||
for i := startPort; i <= endPort; i++ {
|
for i := startPort; i <= endPort; i++ {
|
||||||
if _, ok := fp[i]; !ok {
|
if _, ok := fp[i]; !ok {
|
||||||
fp[i] = &FirewallCA{
|
fp[i] = &FirewallCA{
|
||||||
CANames: make(map[string]*FirewallRule),
|
CANames: make(map[string]*firewallLocalCIDR),
|
||||||
CAShas: make(map[string]*FirewallRule),
|
CAShas: make(map[string]*firewallLocalCIDR),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -669,18 +675,15 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
||||||
fr := func() *FirewallRule {
|
fl := func() *firewallLocalCIDR {
|
||||||
return &FirewallRule{
|
return &firewallLocalCIDR{
|
||||||
Hosts: make(map[string]struct{}),
|
LocalCIDR: cidr.NewTree4[*FirewallRule](),
|
||||||
Groups: make([][]string, 0),
|
|
||||||
CIDR: cidr.NewTree4[struct{}](),
|
|
||||||
LocalCIDR: cidr.NewTree4[struct{}](),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if caSha == "" && caName == "" {
|
if caSha == "" && caName == "" {
|
||||||
if fc.Any == nil {
|
if fc.Any == nil {
|
||||||
fc.Any = fr()
|
fc.Any = fl()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fc.Any.addRule(groups, host, ip, localIp)
|
return fc.Any.addRule(groups, host, ip, localIp)
|
||||||
@ -688,7 +691,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
|||||||
|
|
||||||
if caSha != "" {
|
if caSha != "" {
|
||||||
if _, ok := fc.CAShas[caSha]; !ok {
|
if _, ok := fc.CAShas[caSha]; !ok {
|
||||||
fc.CAShas[caSha] = fr()
|
fc.CAShas[caSha] = fl()
|
||||||
}
|
}
|
||||||
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
|
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -698,7 +701,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
|||||||
|
|
||||||
if caName != "" {
|
if caName != "" {
|
||||||
if _, ok := fc.CANames[caName]; !ok {
|
if _, ok := fc.CANames[caName]; !ok {
|
||||||
fc.CANames[caName] = fr()
|
fc.CANames[caName] = fl()
|
||||||
}
|
}
|
||||||
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
|
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -732,18 +735,63 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
|
|||||||
return fc.CANames[s.Details.Name].match(p, c)
|
return fc.CANames[s.Details.Name].match(p, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
|
func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error {
|
||||||
|
fr := func() *FirewallRule {
|
||||||
|
return &FirewallRule{
|
||||||
|
Hosts: make(map[string]struct{}),
|
||||||
|
Groups: make([][]string, 0),
|
||||||
|
CIDR: cidr.NewTree4[struct{}](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
|
||||||
|
if fc.Any == nil {
|
||||||
|
fc.Any = fr()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Any.addRule(groups, host, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, efr := fc.LocalCIDR.GetCIDR(localIp)
|
||||||
|
if efr != nil {
|
||||||
|
return efr.addRule(groups, host, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
nfr := fr()
|
||||||
|
err := nfr.addRule(groups, host, ip)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fc.LocalCIDR.AddCIDR(localIp, nfr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
||||||
|
if fc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc.Any.match(p, c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool {
|
||||||
|
return fr.match(p, c)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
|
||||||
if fr.Any {
|
if fr.Any {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if fr.isAny(groups, host, ip, localIp) {
|
if fr.isAny(groups, host, ip) {
|
||||||
fr.Any = true
|
fr.Any = true
|
||||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||||
fr.Groups = make([][]string, 0)
|
fr.Groups = make([][]string, 0)
|
||||||
fr.Hosts = make(map[string]struct{})
|
fr.Hosts = make(map[string]struct{})
|
||||||
fr.CIDR = cidr.NewTree4[struct{}]()
|
fr.CIDR = cidr.NewTree4[struct{}]()
|
||||||
fr.LocalCIDR = cidr.NewTree4[struct{}]()
|
|
||||||
} else {
|
} else {
|
||||||
if len(groups) > 0 {
|
if len(groups) > 0 {
|
||||||
fr.Groups = append(fr.Groups, groups)
|
fr.Groups = append(fr.Groups, groups)
|
||||||
@ -756,17 +804,13 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
|
|||||||
if ip != nil {
|
if ip != nil {
|
||||||
fr.CIDR.AddCIDR(ip, struct{}{})
|
fr.CIDR.AddCIDR(ip, struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
if localIp != nil {
|
|
||||||
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
|
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
|
||||||
if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
|
if len(groups) == 0 && host == "" && ip == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -784,10 +828,6 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPN
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -832,13 +872,6 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fr.LocalCIDR != nil {
|
|
||||||
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
|
|
||||||
if ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// No host, group, or cidr matched, bye bye
|
// No host, group, or cidr matched, bye bye
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
217
firewall_test.go
217
firewall_test.go
@ -71,37 +71,34 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
|
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
|
||||||
// An empty rule is any
|
// An empty rule is any
|
||||||
assert.True(t, fw.InRules.TCP[1].Any.Any)
|
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Any.Groups)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
|
||||||
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
assert.False(t, fw.InRules.UDP[1].Any.Any.Any)
|
||||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
assert.Contains(t, fw.InRules.UDP[1].Any.Any.Groups[0], "g1")
|
||||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.UDP[1].Any.Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
|
||||||
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
assert.False(t, fw.InRules.ICMP[1].Any.Any.Any)
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.ICMP[1].Any.Any.Groups)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
assert.Contains(t, fw.InRules.ICMP[1].Any.Any.Hosts, "h1")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
|
||||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.False(t, fw.OutRules.AnyProto[1].Any.Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
ok, _ := fw.OutRules.AnyProto[1].Any.Any.CIDR.GetCIDR(ti)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
|
||||||
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
|
||||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
ok, fr := fw.OutRules.AnyProto[1].Any.LocalCIDR.GetCIDR(ti)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
|
||||||
ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
assert.True(t, fr.Any)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
|
||||||
@ -114,29 +111,28 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
// Set any and clear fields
|
// Set any and clear fields
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
|
||||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
ok, fr = fw.OutRules.AnyProto[0].Any.LocalCIDR.GetCIDR(ti)
|
||||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
|
||||||
ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
|
||||||
assert.True(t, ok)
|
|
||||||
ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
assert.False(t, fr.Any)
|
||||||
|
assert.Equal(t, []string{"g1", "g2"}, fr.Groups[0])
|
||||||
|
assert.Contains(t, fr.Hosts, "h1")
|
||||||
|
|
||||||
// run twice just to make sure
|
// run twice just to make sure
|
||||||
//TODO: these ANY rules should clear the CA firewall portion
|
//TODO: these ANY rules should clear the CA firewall portion
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
@ -231,28 +227,37 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
||||||
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
|
goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
|
||||||
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
|
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
|
||||||
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
|
_ = ft.TCP.addRule(100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
|
||||||
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
|
|
||||||
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
|
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
b.Run("fail on proto", func(b *testing.B) {
|
b.Run("fail on proto", func(b *testing.B) {
|
||||||
|
// This benchmark is showing us the cost of failing to match the protocol
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("fail on port", func(b *testing.B) {
|
b.Run("pass proto, fail on port", func(b *testing.B) {
|
||||||
|
// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("fail all group, name, and cidr", func(b *testing.B) {
|
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
|
||||||
|
c := &cert.NebulaCertificate{}
|
||||||
|
ip, _, _ := net.ParseCIDR("9.254.254.254/32")
|
||||||
|
lip := iputil.Ip2VpnIp(ip)
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||||
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
|
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
|
||||||
c := &cert.NebulaCertificate{
|
c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
Details: cert.NebulaCertificateDetails{
|
||||||
@ -262,11 +267,25 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("pass on group", func(b *testing.B) {
|
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||||
|
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
|
||||||
|
c := &cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
|
Name: "nope",
|
||||||
|
Ips: []*net.IPNet{ip},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
||||||
c := &cert.NebulaCertificate{
|
c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
Details: cert.NebulaCertificateDetails{
|
||||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||||
@ -274,7 +293,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("pass on group on specific local cidr", func(b *testing.B) {
|
||||||
|
c := &cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||||
|
Name: "nope",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -289,60 +320,60 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
//
|
||||||
b.Run("pass on ip", func(b *testing.B) {
|
//b.Run("pass on ip", func(b *testing.B) {
|
||||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||||
c := &cert.NebulaCertificate{
|
// c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
// Details: cert.NebulaCertificateDetails{
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
Name: "good-host",
|
// Name: "good-host",
|
||||||
},
|
// },
|
||||||
}
|
// }
|
||||||
for n := 0; n < b.N; n++ {
|
// for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
|
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
|
||||||
}
|
// }
|
||||||
})
|
//})
|
||||||
|
//
|
||||||
b.Run("pass on local ip", func(b *testing.B) {
|
//b.Run("pass on local ip", func(b *testing.B) {
|
||||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||||
c := &cert.NebulaCertificate{
|
// c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
// Details: cert.NebulaCertificateDetails{
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
Name: "good-host",
|
// Name: "good-host",
|
||||||
},
|
// },
|
||||||
}
|
// }
|
||||||
for n := 0; n < b.N; n++ {
|
// for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
|
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
|
||||||
}
|
// }
|
||||||
})
|
//})
|
||||||
|
//
|
||||||
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
|
//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
|
||||||
|
//
|
||||||
b.Run("pass on ip with any port", func(b *testing.B) {
|
//b.Run("pass on ip with any port", func(b *testing.B) {
|
||||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||||
c := &cert.NebulaCertificate{
|
// c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
// Details: cert.NebulaCertificateDetails{
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
Name: "good-host",
|
// Name: "good-host",
|
||||||
},
|
// },
|
||||||
}
|
// }
|
||||||
for n := 0; n < b.N; n++ {
|
// for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
||||||
}
|
// }
|
||||||
})
|
//})
|
||||||
|
//
|
||||||
b.Run("pass on local ip with any port", func(b *testing.B) {
|
//b.Run("pass on local ip with any port", func(b *testing.B) {
|
||||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||||
c := &cert.NebulaCertificate{
|
// c := &cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
// Details: cert.NebulaCertificateDetails{
|
||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
Name: "good-host",
|
// Name: "good-host",
|
||||||
},
|
// },
|
||||||
}
|
// }
|
||||||
for n := 0; n < b.N; n++ {
|
// for n := 0; n < b.N; n++ {
|
||||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
|
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
|
||||||
}
|
// }
|
||||||
})
|
//})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user