Cert interface (#1212)

This commit is contained in:
Nate Brown
2024-10-10 18:00:22 -05:00
committed by GitHub
parent 16eaae306a
commit 08ac65362e
49 changed files with 2862 additions and 2833 deletions

View File

@@ -52,9 +52,9 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *bart.Table[struct{}]
assignedCIDR netip.Prefix
hasSubnets bool
localIps *bart.Table[struct{}]
assignedCIDR netip.Prefix
hasUnsafeNetworks bool
rules string
rulesVersion uint16
@@ -126,7 +126,7 @@ type firewallLocalCIDR struct {
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration
var min, max time.Duration
@@ -147,11 +147,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
localIps := new(bart.Table[struct{}])
var assignedCIDR netip.Prefix
var assignedSet bool
for _, ip := range c.Details.Ips {
//TODO: IPV6-WORK the unmap is a bit unfortunate
nip, _ := netip.AddrFromSlice(ip.IP)
nip = nip.Unmap()
nprefix := netip.PrefixFrom(nip, nip.BitLen())
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
localIps.Insert(nprefix, struct{}{})
if !assignedSet {
@@ -161,11 +158,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
for _, n := range c.Details.Subnets {
nip, _ := netip.AddrFromSlice(n.IP)
ones, _ := n.Mask.Size()
nip = nip.Unmap()
localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
localIps.Insert(n, struct{}{})
hasUnsafeNetworks = true
}
return &Firewall{
@@ -173,15 +169,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
assignedCIDR: assignedCIDR,
hasSubnets: len(c.Details.Subnets) > 0,
l: l,
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
assignedCIDR: assignedCIDR,
hasUnsafeNetworks: hasUnsafeNetworks,
l: l,
incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
@@ -196,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) {
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
@@ -421,7 +417,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
return nil
@@ -492,7 +488,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@@ -619,7 +615,7 @@ func (f *Firewall) evict(p firewall.Packet) {
delete(conntrack.Conns, p)
}
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
if ft.AnyProto.match(p, incoming, c, caPool) {
return true
}
@@ -663,7 +659,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
return nil
}
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
// We don't have any allowed ports, bail
if fp == nil {
return false
@@ -726,7 +722,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
return nil
}
func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
if fc == nil {
return false
}
@@ -735,18 +731,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return true
}
if t, ok := fc.CAShas[c.Details.Issuer]; ok {
if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok {
if t.match(p, c) {
return true
}
}
s, err := caPool.GetCAForCert(c)
s, err := caPool.GetCAForCert(c.Certificate)
if err != nil {
return false
}
return fc.CANames[s.Details.Name].match(p, c)
return fc.CANames[s.Certificate.Name()].match(p, c)
}
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
@@ -826,7 +822,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
return false
}
func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool {
if fr == nil {
return false
}
@@ -841,7 +837,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
found := false
for _, g := range sg.Groups {
if _, ok := c.Details.InvertedGroups[g]; !ok {
if _, ok := c.InvertedGroups[g]; !ok {
found = false
break
}
@@ -855,7 +851,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
if fr.Hosts != nil {
if flc, ok := fr.Hosts[c.Details.Name]; ok {
if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
if flc.match(p, c) {
return true
}
@@ -876,7 +872,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if !localIp.IsValid() {
if !f.hasSubnets || f.defaultLocalCIDRAny {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
@@ -890,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil
}
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
if flc == nil {
return false
}