mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Add forward tables to handle unsafe network packets distinctly from vpn network packets
This commit is contained in:
134
firewall.go
134
firewall.go
@@ -23,16 +23,15 @@ import (
|
||||
)
|
||||
|
||||
type FirewallInterface interface {
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
||||
AddRule(forward, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
Expires time.Time // Time when this conntrack entry will expire
|
||||
|
||||
// record why the original connection passed the firewall, so we can re-validate
|
||||
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
|
||||
// fields pack for free after the uint32 above
|
||||
// record why the original connection passed the firewall, so we can re-validate after ruleset changes.
|
||||
incoming bool
|
||||
forward bool
|
||||
rulesVersion uint16
|
||||
}
|
||||
|
||||
@@ -40,8 +39,10 @@ type conn struct {
|
||||
type Firewall struct {
|
||||
Conntrack *FirewallConntrack
|
||||
|
||||
InRules *FirewallTable
|
||||
OutRules *FirewallTable
|
||||
InRules *FirewallTable
|
||||
OutRules *FirewallTable
|
||||
ForwardInRules *FirewallTable
|
||||
ForwardOutRules *FirewallTable
|
||||
|
||||
InSendReject bool
|
||||
OutSendReject bool
|
||||
@@ -54,7 +55,7 @@ type Firewall struct {
|
||||
|
||||
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
||||
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
||||
routableNetworks *bart.Lite
|
||||
routableNetworks *bart.Table[NetworkType]
|
||||
|
||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||
assignedNetworks []netip.Prefix
|
||||
@@ -149,17 +150,16 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
tmax = defaultTimeout
|
||||
}
|
||||
|
||||
routableNetworks := new(bart.Lite)
|
||||
routableNetworks := new(bart.Table[NetworkType])
|
||||
var assignedNetworks []netip.Prefix
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
routableNetworks.Insert(nprefix)
|
||||
routableNetworks.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), NetworkTypeVPN)
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
for _, n := range c.UnsafeNetworks() {
|
||||
routableNetworks.Insert(n)
|
||||
routableNetworks.Insert(n, NetworkTypeUnsafe)
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
@@ -170,6 +170,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
ForwardInRules: newFirewallTable(),
|
||||
ForwardOutRules: newFirewallTable(),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
@@ -212,6 +214,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
||||
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
|
||||
|
||||
//TODO: do we also need firewall.forward_inbound_action and firewall.forward_outbound_action?
|
||||
inboundAction := c.GetString("firewall.inbound_action", "drop")
|
||||
switch inboundAction {
|
||||
case "reject":
|
||||
@@ -234,12 +237,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
||||
fw.OutSendReject = false
|
||||
}
|
||||
|
||||
err := AddFirewallRulesFromConfig(l, false, c, fw)
|
||||
err := AddFirewallRulesFromConfig(l, false, false, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = AddFirewallRulesFromConfig(l, true, c, fw)
|
||||
err = AddFirewallRulesFromConfig(l, true, false, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = AddFirewallRulesFromConfig(l, false, true, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = AddFirewallRulesFromConfig(l, true, true, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -248,11 +261,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
||||
}
|
||||
|
||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
||||
func (f *Firewall) AddRule(forward, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||
ruleString := fmt.Sprintf(
|
||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
||||
"forward: %v, incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||
forward, incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
||||
)
|
||||
f.rules += ruleString + "\n"
|
||||
|
||||
@@ -260,8 +273,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
if !incoming {
|
||||
direction = "outgoing"
|
||||
}
|
||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
||||
Info("Firewall rule added")
|
||||
|
||||
fields := m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}
|
||||
if forward {
|
||||
fields["forward"] = true
|
||||
}
|
||||
f.l.WithField("firewallRule", fields).Info("Firewall rule added")
|
||||
|
||||
var (
|
||||
ft *FirewallTable
|
||||
@@ -269,9 +286,18 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
)
|
||||
|
||||
if incoming {
|
||||
ft = f.InRules
|
||||
if forward {
|
||||
ft = f.ForwardInRules
|
||||
} else {
|
||||
ft = f.InRules
|
||||
}
|
||||
|
||||
} else {
|
||||
ft = f.OutRules
|
||||
if forward {
|
||||
ft = f.ForwardOutRules
|
||||
} else {
|
||||
ft = f.OutRules
|
||||
}
|
||||
}
|
||||
|
||||
switch proto {
|
||||
@@ -308,12 +334,21 @@ func (f *Firewall) GetRuleHashes() string {
|
||||
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
||||
}
|
||||
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, forward, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||
var table string
|
||||
if inbound {
|
||||
table = "firewall.inbound"
|
||||
if forward {
|
||||
table = "firewall.forward_inbound"
|
||||
} else {
|
||||
table = "firewall.inbound"
|
||||
}
|
||||
|
||||
} else {
|
||||
table = "firewall.outbound"
|
||||
if forward {
|
||||
table = "firewall.forward_outbound"
|
||||
} else {
|
||||
table = "firewall.outbound"
|
||||
}
|
||||
}
|
||||
|
||||
r := c.Get(table)
|
||||
@@ -386,7 +421,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
l.Warnf("%s rule #%v; %s", table, i, warning)
|
||||
}
|
||||
|
||||
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
||||
err = fw.AddRule(forward, inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
||||
}
|
||||
@@ -409,6 +444,9 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
return nil
|
||||
}
|
||||
|
||||
var remoteNetworkType NetworkType
|
||||
var ok bool
|
||||
|
||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
@@ -416,13 +454,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
remoteNetworkType = NetworkTypeVPN
|
||||
} else {
|
||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
remoteNetworkType, ok = h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
switch nwType {
|
||||
switch remoteNetworkType {
|
||||
case NetworkTypeVPN:
|
||||
break // nothing special
|
||||
case NetworkTypeVPNPeer:
|
||||
@@ -437,14 +476,27 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
if !f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
localNetworkType, ok := f.routableNetworks.Lookup(fp.LocalAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
table := f.OutRules
|
||||
useForward := remoteNetworkType == NetworkTypeUnsafe || localNetworkType == NetworkTypeUnsafe
|
||||
|
||||
var table *FirewallTable
|
||||
if incoming {
|
||||
table = f.InRules
|
||||
if useForward {
|
||||
table = f.ForwardInRules
|
||||
} else {
|
||||
table = f.InRules
|
||||
}
|
||||
} else {
|
||||
if useForward {
|
||||
table = f.ForwardOutRules
|
||||
} else {
|
||||
table = f.OutRules
|
||||
}
|
||||
}
|
||||
|
||||
// We now know which firewall table to check against
|
||||
@@ -454,12 +506,13 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
}
|
||||
|
||||
// We always want to conntrack since it is a faster operation
|
||||
f.addConn(fp, incoming)
|
||||
f.addConn(fp, useForward, incoming)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) metrics(incoming bool) firewallMetrics {
|
||||
//TODO: need forward metrics too
|
||||
if incoming {
|
||||
return f.incomingMetrics
|
||||
} else {
|
||||
@@ -499,7 +552,6 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
}
|
||||
|
||||
c, ok := conntrack.Conns[fp]
|
||||
|
||||
if !ok {
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
@@ -508,9 +560,19 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
if c.rulesVersion != f.rulesVersion {
|
||||
// This conntrack entry was for an older rule set, validate
|
||||
// it still passes with the current rule set
|
||||
table := f.OutRules
|
||||
var table *FirewallTable
|
||||
if c.incoming {
|
||||
table = f.InRules
|
||||
if c.forward {
|
||||
table = f.ForwardInRules
|
||||
} else {
|
||||
table = f.InRules
|
||||
}
|
||||
} else {
|
||||
if c.forward {
|
||||
table = f.ForwardOutRules
|
||||
} else {
|
||||
table = f.OutRules
|
||||
}
|
||||
}
|
||||
|
||||
// We now know which firewall table to check against
|
||||
@@ -519,6 +581,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("forward", c.forward).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("dropping old conntrack entry, does not match new ruleset")
|
||||
@@ -532,6 +595,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("forward", c.forward).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("keeping old conntrack entry, does match new ruleset")
|
||||
@@ -558,7 +622,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
func (f *Firewall) addConn(fp firewall.Packet, forward, incoming bool) {
|
||||
var timeout time.Duration
|
||||
c := &conn{}
|
||||
|
||||
@@ -581,6 +645,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
// Record which rulesVersion allowed this connection, so we can retest after
|
||||
// firewall reload
|
||||
c.incoming = incoming
|
||||
c.forward = forward
|
||||
c.rulesVersion = f.rulesVersion
|
||||
c.Expires = time.Now().Add(timeout)
|
||||
conntrack.Conns[fp] = c
|
||||
@@ -937,6 +1002,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
||||
r.Code = toString("code", m)
|
||||
r.Proto = toString("proto", m)
|
||||
r.Host = toString("host", m)
|
||||
//TODO: create an alias to remote_cidr and deprecate cidr?
|
||||
r.Cidr = toString("cidr", m)
|
||||
r.LocalCidr = toString("local_cidr", m)
|
||||
r.CAName = toString("ca_name", m)
|
||||
|
||||
Reference in New Issue
Block a user