Remove tcp rtt tracking from the firewall (#1114)

This commit is contained in:
Nate Brown
2024-04-11 21:44:22 -05:00
committed by GitHub
parent 7efa750aef
commit c1711bc9c5
4 changed files with 26 additions and 175 deletions

View File

@@ -2,7 +2,6 @@ package nebula
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -22,17 +21,12 @@ import (
"github.com/slackhq/nebula/firewall"
)
const tcpACK = 0x10
const tcpFIN = 0x01
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
@@ -66,8 +60,6 @@ type Firewall struct {
rulesVersion uint16
defaultLocalCIDRAny bool
trackTCPRTT bool
metricTCPRTT metrics.Histogram
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
@@ -183,7 +175,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
hasSubnets: len(c.Details.Subnets) > 0,
l: l,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
@@ -422,9 +413,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
if f.inConns(fp, h, caPool, localCache) {
return nil
}
@@ -462,7 +453,7 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
}
// We always want to conntrack since it is a faster operation
f.addConn(packet, fp, incoming)
f.addConn(fp, incoming)
return nil
}
@@ -491,7 +482,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@@ -551,11 +542,6 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *
switch fp.Protocol {
case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
if incoming {
f.checkTCPRTT(c, packet)
} else {
setTCPRTTTracking(c, packet)
}
case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout)
default:
@@ -571,16 +557,13 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *
return true
}
func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case firewall.ProtoTCP:
timeout = f.TCPTimeout
if !incoming {
setTCPRTTTracking(c, packet)
}
case firewall.ProtoUDP:
timeout = f.UDPTimeout
default:
@@ -1017,42 +1000,3 @@ func parsePort(s string) (startPort, endPort int32, err error) {
return
}
// TODO: write tests for these
func setTCPRTTTracking(c *conn, p []byte) {
if c.Seq != 0 {
return
}
ihl := int(p[0]&0x0f) << 2
// Don't track FIN packets
if p[ihl+13]&tcpFIN != 0 {
return
}
c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8])
c.Sent = time.Now()
}
func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
if c.Seq == 0 {
return false
}
ihl := int(p[0]&0x0f) << 2
if p[ihl+13]&tcpACK == 0 {
return false
}
// Deal with wrap around, signed int cuts the ack window in half
// 0 is a bad ack, no data acknowledged
// positive number is a bad ack, ack is over half the window away
if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 {
return false
}
f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds())
c.Seq = 0
return true
}