At the end

This commit is contained in:
Nate Brown
2024-01-29 15:30:52 -06:00
parent 8f44f22c37
commit f346cf4109
4 changed files with 143 additions and 131 deletions

View File

@@ -58,7 +58,9 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *cidr.Tree4[struct{}]
localIps *cidr.Tree4[struct{}]
assignedCIDR *net.IPNet
hasSubnets bool
rules string
rulesVersion uint16
@@ -103,17 +105,22 @@ func newFirewallTable() *FirewallTable {
}
type FirewallCA struct {
Any *firewallLocalCIDR
CANames map[string]*firewallLocalCIDR
CAShas map[string]*firewallLocalCIDR
Any *FirewallRule
CANames map[string]*FirewallRule
CAShas map[string]*FirewallRule
}
type FirewallRule struct {
// Any makes Hosts, Groups, and CIDR irrelevant
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *cidr.Tree4[struct{}]
Any *firewallLocalCIDR
Hosts map[string]*firewallLocalCIDR
Groups []*firewallGroups
CIDR *cidr.Tree4[*firewallLocalCIDR]
}
type firewallGroups struct {
Groups []string
LocalCIDR *firewallLocalCIDR
}
// Even though ports are uint16, int32 maps are faster for lookup
@@ -121,8 +128,8 @@ type FirewallRule struct {
type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any *FirewallRule
LocalCIDR *cidr.Tree4[*FirewallRule]
Any bool
LocalCIDR *cidr.Tree4[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -145,8 +152,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
localIps := cidr.NewTree4[struct{}]()
var assignedCIDR *net.IPNet
for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
localIps.AddCIDR(ipNet, struct{}{})
if assignedCIDR == nil {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = ipNet
}
}
for _, n := range c.Details.Subnets {
@@ -164,6 +178,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
assignedCIDR: assignedCIDR,
hasSubnets: len(c.Details.Subnets) > 0,
l: l,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -276,7 +292,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -630,7 +646,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false
}
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@@ -638,12 +654,12 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
for i := startPort; i <= endPort; i++ {
if _, ok := fp[i]; !ok {
fp[i] = &FirewallCA{
CANames: make(map[string]*firewallLocalCIDR),
CAShas: make(map[string]*firewallLocalCIDR),
CANames: make(map[string]*FirewallRule),
CAShas: make(map[string]*FirewallRule),
}
}
if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err
}
}
@@ -674,26 +690,28 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fl := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: cidr.NewTree4[*FirewallRule](),
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR),
Groups: make([]*firewallGroups, 0),
CIDR: cidr.NewTree4[*firewallLocalCIDR](),
}
}
if caSha == "" && caName == "" {
if fc.Any == nil {
fc.Any = fl()
fc.Any = fr()
}
return fc.Any.addRule(groups, host, ip, localIp)
return fc.Any.addRule(f, groups, host, ip, localIp)
}
if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fl()
fc.CAShas[caSha] = fr()
}
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil {
return err
}
@@ -701,9 +719,9 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
if caName != "" {
if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fl()
fc.CANames[caName] = fr()
}
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil {
return err
}
@@ -735,75 +753,56 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c)
}
func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4[struct{}](),
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: cidr.NewTree4[struct{}](),
}
}
if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
if fc.Any == nil {
fc.Any = fr()
}
return fc.Any.addRule(groups, host, ip)
}
_, efr := fc.LocalCIDR.GetCIDR(localIp)
if efr != nil {
return efr.addRule(groups, host, ip)
}
nfr := fr()
err := nfr.addRule(groups, host, ip)
if err != nil {
return err
}
fc.LocalCIDR.AddCIDR(localIp, nfr)
return nil
}
func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if fc == nil {
return false
}
if fc.Any.match(p, c) {
return true
}
return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool {
return fr.match(p, c)
})
}
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
if fr.Any {
return nil
}
if fr.isAny(groups, host, ip) {
fr.Any = true
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4[struct{}]()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
if fr.Any == nil {
fr.Any = flc()
}
if host != "" {
fr.Hosts[host] = struct{}{}
return fr.Any.addRule(f, localCIDR)
}
if len(groups) > 0 {
nlc := flc()
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
if ip != nil {
fr.CIDR.AddCIDR(ip, struct{}{})
fr.Groups = append(fr.Groups, &firewallGroups{
Groups: groups,
LocalCIDR: nlc,
})
}
if host != "" {
nlc := fr.Hosts[host]
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.Hosts[host] = nlc
}
if ip != nil {
_, nlc := fr.CIDR.GetCIDR(ip)
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.CIDR.AddCIDR(ip, nlc)
}
return nil
@@ -837,7 +836,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
// Shortcut path for if groups, hosts, or cidr contained an `any`
if fr.Any {
if fr.Any.match(p, c) {
return true
}
@@ -845,7 +844,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
for _, sg := range fr.Groups {
found := false
for _, g := range sg {
for _, g := range sg.Groups {
if _, ok := c.Details.InvertedGroups[g]; !ok {
found = false
break
@@ -854,26 +853,48 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
found = true
}
if found {
if found && sg.LocalCIDR.match(p, c) {
return true
}
}
if fr.Hosts != nil {
if _, ok := fr.Hosts[c.Details.Name]; ok {
return true
if flc, ok := fr.Hosts[c.Details.Name]; ok {
if flc.match(p, c) {
return true
}
}
}
if fr.CIDR != nil {
ok, _ := fr.CIDR.Contains(p.RemoteIP)
if ok {
return true
return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
return flc.match(p, c)
})
}
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
if !f.hasSubnets {
flc.Any = true
return nil
}
localIp = f.assignedCIDR
}
// No host, group, or cidr matched, bye bye
return false
flc.LocalCIDR.AddCIDR(localIp, struct{}{})
return nil
}
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if flc == nil {
return false
}
if flc.Any {
return true
}
ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
return ok
}
type rule struct {