mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 00:15:37 +01:00
V2 certificate format (#1216)
Co-authored-by: Nate Brown <nbrown.us@gmail.com> Co-authored-by: Jack Doan <jackdoan@rivian.com> Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com> Co-authored-by: Jack Doan <me@jackdoan.com>
This commit is contained in:
121
firewall.go
121
firewall.go
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
type FirewallInterface interface {
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
@@ -51,9 +51,12 @@ type Firewall struct {
|
||||
UDPTimeout time.Duration //linux: 180s max
|
||||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *bart.Table[struct{}]
|
||||
assignedCIDR netip.Prefix
|
||||
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
||||
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
||||
routableNetworks *bart.Table[struct{}]
|
||||
|
||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||
assignedNetworks []netip.Prefix
|
||||
hasUnsafeNetworks bool
|
||||
|
||||
rules string
|
||||
@@ -67,9 +70,9 @@ type Firewall struct {
|
||||
}
|
||||
|
||||
type firewallMetrics struct {
|
||||
droppedLocalIP metrics.Counter
|
||||
droppedRemoteIP metrics.Counter
|
||||
droppedNoRule metrics.Counter
|
||||
droppedLocalAddr metrics.Counter
|
||||
droppedRemoteAddr metrics.Counter
|
||||
droppedNoRule metrics.Counter
|
||||
}
|
||||
|
||||
type FirewallConntrack struct {
|
||||
@@ -126,84 +129,87 @@ type firewallLocalCIDR struct {
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
// The certificate provided should be the highest version loaded in memory.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
var min, max time.Duration
|
||||
var tmin, tmax time.Duration
|
||||
|
||||
if tcpTimeout < UDPTimeout {
|
||||
min = tcpTimeout
|
||||
max = UDPTimeout
|
||||
tmin = tcpTimeout
|
||||
tmax = UDPTimeout
|
||||
} else {
|
||||
min = UDPTimeout
|
||||
max = tcpTimeout
|
||||
tmin = UDPTimeout
|
||||
tmax = tcpTimeout
|
||||
}
|
||||
|
||||
if defaultTimeout < min {
|
||||
min = defaultTimeout
|
||||
} else if defaultTimeout > max {
|
||||
max = defaultTimeout
|
||||
if defaultTimeout < tmin {
|
||||
tmin = defaultTimeout
|
||||
} else if defaultTimeout > tmax {
|
||||
tmax = defaultTimeout
|
||||
}
|
||||
|
||||
localIps := new(bart.Table[struct{}])
|
||||
var assignedCIDR netip.Prefix
|
||||
var assignedSet bool
|
||||
routableNetworks := new(bart.Table[struct{}])
|
||||
var assignedNetworks []netip.Prefix
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
localIps.Insert(nprefix, struct{}{})
|
||||
|
||||
if !assignedSet {
|
||||
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
|
||||
assignedCIDR = nprefix
|
||||
assignedSet = true
|
||||
}
|
||||
routableNetworks.Insert(nprefix, struct{}{})
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
for _, n := range c.UnsafeNetworks() {
|
||||
localIps.Insert(n, struct{}{})
|
||||
routableNetworks.Insert(n, struct{}{})
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
|
||||
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
localIps: localIps,
|
||||
assignedCIDR: assignedCIDR,
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||
l: l,
|
||||
|
||||
incomingMetrics: firewallMetrics{
|
||||
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
|
||||
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
|
||||
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
|
||||
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
|
||||
},
|
||||
outgoingMetrics: firewallMetrics{
|
||||
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
|
||||
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
|
||||
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
|
||||
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
|
||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) {
|
||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||
certificate := cs.getCertificate(cert.Version2)
|
||||
if certificate == nil {
|
||||
certificate = cs.getCertificate(cert.Version1)
|
||||
}
|
||||
|
||||
if certificate == nil {
|
||||
panic("No certificate available to reconfigure the firewall")
|
||||
}
|
||||
|
||||
fw := NewFirewall(
|
||||
l,
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||
nc,
|
||||
certificate,
|
||||
//TODO: max_connections
|
||||
)
|
||||
|
||||
//TODO: Flip to false after v1.9 release
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
|
||||
|
||||
inboundAction := c.GetString("firewall.inbound_action", "drop")
|
||||
switch inboundAction {
|
||||
@@ -283,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
fp = ft.TCP
|
||||
case firewall.ProtoUDP:
|
||||
fp = ft.UDP
|
||||
case firewall.ProtoICMP:
|
||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||
fp = ft.ICMP
|
||||
case firewall.ProtoAny:
|
||||
fp = ft.AnyProto
|
||||
@@ -424,26 +430,24 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
||||
_, ok := remoteCidr.Lookup(fp.RemoteIP)
|
||||
if h.networks != nil {
|
||||
_, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
// Simple case: Certificate has one IP and no subnets
|
||||
if fp.RemoteIP != h.vpnIp {
|
||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
||||
_, ok := f.localIps.Lookup(fp.LocalIP)
|
||||
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedLocalIP.Inc(1)
|
||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
@@ -629,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
|
||||
if ft.UDP.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
case firewall.ProtoICMP:
|
||||
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||
if ft.ICMP.match(p, incoming, c, caPool) {
|
||||
return true
|
||||
}
|
||||
@@ -859,9 +863,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
|
||||
}
|
||||
|
||||
matched := false
|
||||
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
|
||||
prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
|
||||
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
|
||||
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
|
||||
if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
|
||||
matched = true
|
||||
return false
|
||||
}
|
||||
@@ -877,9 +881,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
localIp = f.assignedCIDR
|
||||
for _, network := range f.assignedNetworks {
|
||||
flc.LocalCIDR.Insert(network, struct{}{})
|
||||
}
|
||||
return nil
|
||||
|
||||
} else if localIp.Bits() == 0 {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
|
||||
flc.LocalCIDR.Insert(localIp, struct{}{})
|
||||
@@ -895,7 +904,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
|
||||
return true
|
||||
}
|
||||
|
||||
_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
|
||||
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
|
||||
return ok
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user