mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Use generics for CIDRTrees to avoid casting issues (#1004)
This commit is contained in:
36
firewall.go
36
firewall.go
@@ -57,7 +57,7 @@ type Firewall struct {
|
||||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *cidr.Tree4
|
||||
localIps *cidr.Tree4[struct{}]
|
||||
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
@@ -110,8 +110,8 @@ type FirewallRule struct {
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *cidr.Tree4
|
||||
LocalCIDR *cidr.Tree4
|
||||
CIDR *cidr.Tree4[struct{}]
|
||||
LocalCIDR *cidr.Tree4[struct{}]
|
||||
}
|
||||
|
||||
// Even though ports are uint16, int32 maps are faster for lookup
|
||||
@@ -137,7 +137,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
max = defaultTimeout
|
||||
}
|
||||
|
||||
localIps := cidr.NewTree4()
|
||||
localIps := cidr.NewTree4[struct{}]()
|
||||
for _, ip := range c.Details.Ips {
|
||||
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||
}
|
||||
@@ -391,7 +391,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
|
||||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
||||
if remoteCidr.Contains(fp.RemoteIP) == nil {
|
||||
ok, _ := remoteCidr.Contains(fp.RemoteIP)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
@@ -404,7 +405,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
if f.localIps.Contains(fp.LocalIP) == nil {
|
||||
ok, _ := f.localIps.Contains(fp.LocalIP)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedLocalIP.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
@@ -657,8 +659,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: cidr.NewTree4(),
|
||||
LocalCIDR: cidr.NewTree4(),
|
||||
CIDR: cidr.NewTree4[struct{}](),
|
||||
LocalCIDR: cidr.NewTree4[struct{}](),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -726,8 +728,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
|
||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||
fr.Groups = make([][]string, 0)
|
||||
fr.Hosts = make(map[string]struct{})
|
||||
fr.CIDR = cidr.NewTree4()
|
||||
fr.LocalCIDR = cidr.NewTree4()
|
||||
fr.CIDR = cidr.NewTree4[struct{}]()
|
||||
fr.LocalCIDR = cidr.NewTree4[struct{}]()
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
fr.Groups = append(fr.Groups, groups)
|
||||
@@ -809,12 +811,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
}
|
||||
}
|
||||
|
||||
if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
|
||||
return true
|
||||
if fr.CIDR != nil {
|
||||
ok, _ := fr.CIDR.Contains(p.RemoteIP)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
|
||||
return true
|
||||
if fr.LocalCIDR != nil {
|
||||
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// No host, group, or cidr matched, bye bye
|
||||
|
||||
Reference in New Issue
Block a user