mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
In the middle
This commit is contained in:
113
firewall.go
113
firewall.go
@@ -84,6 +84,8 @@ type FirewallConntrack struct {
|
||||
TimerWheel *TimerWheel[firewall.Packet]
|
||||
}
|
||||
|
||||
// FirewallTable is the entry point for a rule, the evaluation order is:
|
||||
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
|
||||
type FirewallTable struct {
|
||||
TCP firewallPort
|
||||
UDP firewallPort
|
||||
@@ -101,24 +103,28 @@ func newFirewallTable() *FirewallTable {
|
||||
}
|
||||
|
||||
type FirewallCA struct {
|
||||
Any *FirewallRule
|
||||
CANames map[string]*FirewallRule
|
||||
CAShas map[string]*FirewallRule
|
||||
Any *firewallLocalCIDR
|
||||
CANames map[string]*firewallLocalCIDR
|
||||
CAShas map[string]*firewallLocalCIDR
|
||||
}
|
||||
|
||||
type FirewallRule struct {
|
||||
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *cidr.Tree4[struct{}]
|
||||
LocalCIDR *cidr.Tree4[struct{}]
|
||||
// Any makes Hosts, Groups, and CIDR irrelevant
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *cidr.Tree4[struct{}]
|
||||
}
|
||||
|
||||
// Even though ports are uint16, int32 maps are faster for lookup
|
||||
// Plus we can use `-1` for fragment rules
|
||||
type firewallPort map[int32]*FirewallCA
|
||||
|
||||
type firewallLocalCIDR struct {
|
||||
Any *FirewallRule
|
||||
LocalCIDR *cidr.Tree4[*FirewallRule]
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
@@ -632,8 +638,8 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
||||
for i := startPort; i <= endPort; i++ {
|
||||
if _, ok := fp[i]; !ok {
|
||||
fp[i] = &FirewallCA{
|
||||
CANames: make(map[string]*FirewallRule),
|
||||
CAShas: make(map[string]*FirewallRule),
|
||||
CANames: make(map[string]*firewallLocalCIDR),
|
||||
CAShas: make(map[string]*firewallLocalCIDR),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -669,18 +675,15 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
||||
fr := func() *FirewallRule {
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: cidr.NewTree4[struct{}](),
|
||||
LocalCIDR: cidr.NewTree4[struct{}](),
|
||||
fl := func() *firewallLocalCIDR {
|
||||
return &firewallLocalCIDR{
|
||||
LocalCIDR: cidr.NewTree4[*FirewallRule](),
|
||||
}
|
||||
}
|
||||
|
||||
if caSha == "" && caName == "" {
|
||||
if fc.Any == nil {
|
||||
fc.Any = fr()
|
||||
fc.Any = fl()
|
||||
}
|
||||
|
||||
return fc.Any.addRule(groups, host, ip, localIp)
|
||||
@@ -688,7 +691,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
||||
|
||||
if caSha != "" {
|
||||
if _, ok := fc.CAShas[caSha]; !ok {
|
||||
fc.CAShas[caSha] = fr()
|
||||
fc.CAShas[caSha] = fl()
|
||||
}
|
||||
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
|
||||
if err != nil {
|
||||
@@ -698,7 +701,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
||||
|
||||
if caName != "" {
|
||||
if _, ok := fc.CANames[caName]; !ok {
|
||||
fc.CANames[caName] = fr()
|
||||
fc.CANames[caName] = fl()
|
||||
}
|
||||
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
|
||||
if err != nil {
|
||||
@@ -732,18 +735,63 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
|
||||
return fc.CANames[s.Details.Name].match(p, c)
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
|
||||
func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error {
|
||||
fr := func() *FirewallRule {
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: cidr.NewTree4[struct{}](),
|
||||
}
|
||||
}
|
||||
|
||||
if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
|
||||
if fc.Any == nil {
|
||||
fc.Any = fr()
|
||||
}
|
||||
|
||||
return fc.Any.addRule(groups, host, ip)
|
||||
}
|
||||
|
||||
_, efr := fc.LocalCIDR.GetCIDR(localIp)
|
||||
if efr != nil {
|
||||
return efr.addRule(groups, host, ip)
|
||||
}
|
||||
|
||||
nfr := fr()
|
||||
err := nfr.addRule(groups, host, ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fc.LocalCIDR.AddCIDR(localIp, nfr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
||||
if fc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if fc.Any.match(p, c) {
|
||||
return true
|
||||
}
|
||||
|
||||
return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool {
|
||||
return fr.match(p, c)
|
||||
})
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
|
||||
if fr.Any {
|
||||
return nil
|
||||
}
|
||||
|
||||
if fr.isAny(groups, host, ip, localIp) {
|
||||
if fr.isAny(groups, host, ip) {
|
||||
fr.Any = true
|
||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||
fr.Groups = make([][]string, 0)
|
||||
fr.Hosts = make(map[string]struct{})
|
||||
fr.CIDR = cidr.NewTree4[struct{}]()
|
||||
fr.LocalCIDR = cidr.NewTree4[struct{}]()
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
fr.Groups = append(fr.Groups, groups)
|
||||
@@ -756,17 +804,13 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
|
||||
if ip != nil {
|
||||
fr.CIDR.AddCIDR(ip, struct{}{})
|
||||
}
|
||||
|
||||
if localIp != nil {
|
||||
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
|
||||
if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
|
||||
if len(groups) == 0 && host == "" && ip == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -784,10 +828,6 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPN
|
||||
return true
|
||||
}
|
||||
|
||||
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -832,13 +872,6 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
}
|
||||
}
|
||||
|
||||
if fr.LocalCIDR != nil {
|
||||
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// No host, group, or cidr matched, bye bye
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user