Merge remote-tracking branch 'origin/master' into multiport

This commit is contained in:
Wade Simmons
2023-10-27 08:48:13 -04:00
74 changed files with 2540 additions and 1402 deletions

View File

@@ -13,7 +13,6 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -26,9 +25,9 @@ const mtu = 9001
type InterfaceConfig struct {
HostMap *HostMap
Outside *udp.Conn
Outside udp.Conn
Inside overlay.Device
certState *CertState
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
@@ -41,20 +40,23 @@ type InterfaceConfig struct {
routines int
MessageMetrics *MessageMetrics
version string
caPool *cert.NebulaCAPool
disconnectInvalid bool
relayManager *relayManager
punchy *Punchy
tryPromoteEvery uint32
reQueryEvery uint32
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
l *logrus.Logger
}
type Interface struct {
hostMap *HostMap
outside *udp.Conn
outside udp.Conn
inside overlay.Device
certState atomic.Pointer[CertState]
pki *PKI
cipher string
firewall *Firewall
connectionManager *connectionManager
@@ -67,11 +69,14 @@ type Interface struct {
dropLocalBroadcast bool
dropMulticast bool
routines int
caPool *cert.NebulaCAPool
disconnectInvalid bool
closed atomic.Bool
relayManager *relayManager
tryPromoteEvery atomic.Uint32
reQueryEvery atomic.Uint32
reQueryWait atomic.Int64
sendRecvErrorConfig sendRecvErrorConfig
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
@@ -80,7 +85,7 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []*udp.Conn
writers []udp.Conn
readers []io.ReadWriteCloser
udpRaw *udp.RawConn
@@ -156,15 +161,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Inside == nil {
return nil, errors.New("no inside interface (tun)")
}
if c.certState == nil {
if c.pki == nil {
return nil, errors.New("no certificate state")
}
if c.Firewall == nil {
return nil, errors.New("no firewall rules")
}
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
certificate := c.pki.GetCertState().Certificate
myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
ifce := &Interface{
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
@@ -174,14 +181,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
routines: c.routines,
version: c.version,
writers: make([]*udp.Conn, c.routines),
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool,
disconnectInvalid: c.disconnectInvalid,
myVpnIp: myVpnIp,
relayManager: c.relayManager,
@@ -198,7 +204,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
ifce.certState.Store(c.certState)
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
return ifce, nil
@@ -257,7 +266,7 @@ func (f *Interface) run() {
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li *udp.Conn
var li udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 {
li = f.writers[i]
@@ -297,49 +306,14 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadCA)
c.RegisterReloadCallback(f.reloadCertKey)
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError)
c.RegisterReloadCallback(f.reloadMisc)
for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.ReloadConfig)
}
}
func (f *Interface) reloadCA(c *config.C) {
// reload and check regardless
// todo: need mutex?
newCAs, err := loadCAFromConfig(f.l, c)
if err != nil {
f.l.WithError(err).Error("Could not refresh trusted CA certificates")
return
}
f.caPool = newCAs
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
}
func (f *Interface) reloadCertKey(c *config.C) {
// reload and check in all cases
cs, err := NewCertStateFromConfig(c)
if err != nil {
f.l.WithError(err).Error("Could not refresh client cert")
return
}
// did IP in cert change? if so, don't set
currentCert := f.certState.Load().certificate
oldIPs := currentCert.Details.Ips
newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return
}
f.certState.Store(cs)
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}
func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
@@ -347,7 +321,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}
fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
@@ -403,6 +377,26 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
}
}
func (f *Interface) reloadMisc(c *config.C) {
if c.HasChanged("counters.try_promote") {
n := c.GetUint32("counters.try_promote", defaultPromoteEvery)
f.tryPromoteEvery.Store(n)
f.l.Info("counters.try_promote has changed")
}
if c.HasChanged("counters.requery_every_packets") {
n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery)
f.reQueryEvery.Store(n)
f.l.Info("counters.requery_every_packets has changed")
}
if c.HasChanged("timers.requery_wait_duration") {
n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait)
f.reQueryWait.Store(int64(n))
f.l.Info("timers.requery_wait_duration has changed")
}
}
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i)
defer ticker.Stop()
@@ -427,7 +421,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
}
rawStats()
}
certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
}
}
}
@@ -435,6 +429,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
func (f *Interface) Close() error {
f.closed.Store(true)
for _, u := range f.writers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing udp socket")
}
}
// Release the tun device
return f.inside.Close()
}