mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 00:44:25 +01:00
Use generics for CIDRTrees to avoid casting issues (#1004)
This commit is contained in:
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
type AllowList struct {
|
||||
// The values of this cidrTree are `bool`, signifying allow/deny
|
||||
cidrTree *cidr.Tree6
|
||||
cidrTree *cidr.Tree6[bool]
|
||||
}
|
||||
|
||||
type RemoteAllowList struct {
|
||||
@@ -20,7 +20,7 @@ type RemoteAllowList struct {
|
||||
|
||||
// Inside Range Specific, keys of this tree are inside CIDRs and values
|
||||
// are *AllowList
|
||||
insideAllowLists *cidr.Tree6
|
||||
insideAllowLists *cidr.Tree6[*AllowList]
|
||||
}
|
||||
|
||||
type LocalAllowList struct {
|
||||
@@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
||||
}
|
||||
|
||||
tree := cidr.NewTree6()
|
||||
tree := cidr.NewTree6[bool]()
|
||||
|
||||
// Keep track of the rules we have added for both ipv4 and ipv6
|
||||
type allowListRules struct {
|
||||
@@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
|
||||
return nameRules, nil
|
||||
}
|
||||
|
||||
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
|
||||
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
|
||||
value := c.Get(k)
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
remoteAllowRanges := cidr.NewTree6()
|
||||
remoteAllowRanges := cidr.NewTree6[*AllowList]()
|
||||
|
||||
rawMap, ok := value.(map[interface{}]interface{})
|
||||
if !ok {
|
||||
@@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
result := al.cidrTree.MostSpecificContains(ip)
|
||||
switch v := result.(type) {
|
||||
case bool:
|
||||
return v
|
||||
default:
|
||||
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
|
||||
}
|
||||
_, result := al.cidrTree.MostSpecificContains(ip)
|
||||
return result
|
||||
}
|
||||
|
||||
func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
|
||||
@@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
result := al.cidrTree.MostSpecificContainsIpV4(ip)
|
||||
switch v := result.(type) {
|
||||
case bool:
|
||||
return v
|
||||
default:
|
||||
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
|
||||
}
|
||||
_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
|
||||
return result
|
||||
}
|
||||
|
||||
func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
|
||||
@@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
|
||||
switch v := result.(type) {
|
||||
case bool:
|
||||
return v
|
||||
default:
|
||||
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
|
||||
}
|
||||
_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
|
||||
return result
|
||||
}
|
||||
|
||||
func (al *LocalAllowList) Allow(ip net.IP) bool {
|
||||
@@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
|
||||
|
||||
func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
|
||||
if al.insideAllowLists != nil {
|
||||
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
|
||||
if inside != nil {
|
||||
return inside.(*AllowList)
|
||||
ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
|
||||
if ok {
|
||||
return inside
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user