Use generics for CIDRTrees to avoid casting issues (#1004)

This commit is contained in:
Nate Brown
2023-11-02 17:05:08 -05:00
committed by GitHub
parent a44e1b8b05
commit 5181cb0474
21 changed files with 264 additions and 247 deletions

View File

@@ -57,7 +57,7 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *cidr.Tree4
localIps *cidr.Tree4[struct{}]
rules string
rulesVersion uint16
@@ -110,8 +110,8 @@ type FirewallRule struct {
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *cidr.Tree4
LocalCIDR *cidr.Tree4
CIDR *cidr.Tree4[struct{}]
LocalCIDR *cidr.Tree4[struct{}]
}
// Even though ports are uint16, int32 maps are faster for lookup
@@ -137,7 +137,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout
}
localIps := cidr.NewTree4()
localIps := cidr.NewTree4[struct{}]()
for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
@@ -391,7 +391,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
if remoteCidr.Contains(fp.RemoteIP) == nil {
ok, _ := remoteCidr.Contains(fp.RemoteIP)
if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
}
@@ -404,7 +405,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
}
// Make sure we are supposed to be handling this local ip address
if f.localIps.Contains(fp.LocalIP) == nil {
ok, _ := f.localIps.Contains(fp.LocalIP)
if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1)
return ErrInvalidLocalIP
}
@@ -657,8 +659,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4(),
LocalCIDR: cidr.NewTree4(),
CIDR: cidr.NewTree4[struct{}](),
LocalCIDR: cidr.NewTree4[struct{}](),
}
}
@@ -726,8 +728,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4()
fr.LocalCIDR = cidr.NewTree4()
fr.CIDR = cidr.NewTree4[struct{}]()
fr.LocalCIDR = cidr.NewTree4[struct{}]()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
@@ -809,12 +811,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
}
if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
return true
if fr.CIDR != nil {
ok, _ := fr.CIDR.Contains(p.RemoteIP)
if ok {
return true
}
}
if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
return true
if fr.LocalCIDR != nil {
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
if ok {
return true
}
}
// No host, group, or cidr matched, bye bye