mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-08 19:53:59 +01:00
Use atomic.Pointer for certState (#833)
This commit is contained in:
parent
2801fb2286
commit
6b3d42efa5
@ -54,12 +54,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &test.NoopTun{},
|
||||||
outside: &udp.Conn{},
|
outside: &udp.Conn{},
|
||||||
certState: cs,
|
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
ifce.certState.Store(cs)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
@ -130,12 +130,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &test.NoopTun{},
|
||||||
outside: &udp.Conn{},
|
outside: &udp.Conn{},
|
||||||
certState: cs,
|
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
ifce.certState.Store(cs)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
@ -245,7 +245,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &test.NoopTun{},
|
inside: &test.NoopTun{},
|
||||||
outside: &udp.Conn{},
|
outside: &udp.Conn{},
|
||||||
certState: cs,
|
|
||||||
firewall: &Firewall{},
|
firewall: &Firewall{},
|
||||||
lightHouse: lh,
|
lightHouse: lh,
|
||||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||||
@ -253,6 +252,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
disconnectInvalid: true,
|
disconnectInvalid: true,
|
||||||
caPool: ncp,
|
caPool: ncp,
|
||||||
}
|
}
|
||||||
|
ifce.certState.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|||||||
@ -33,7 +33,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
|||||||
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
}
|
}
|
||||||
|
|
||||||
curCertState := f.certState
|
curCertState := f.certState.Load()
|
||||||
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
|
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
|
||||||
|
|
||||||
b := NewBits(ReplayWindow)
|
b := NewBits(ReplayWindow)
|
||||||
|
|||||||
@ -161,5 +161,5 @@ func (c *Control) GetHostmap() *HostMap {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetCert() *cert.NebulaCertificate {
|
func (c *Control) GetCert() *cert.NebulaCertificate {
|
||||||
return c.f.certState.certificate
|
return c.f.certState.Load().certificate
|
||||||
}
|
}
|
||||||
|
|||||||
11
interface.go
11
interface.go
@ -52,7 +52,7 @@ type Interface struct {
|
|||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
outside *udp.Conn
|
outside *udp.Conn
|
||||||
inside overlay.Device
|
inside overlay.Device
|
||||||
certState *CertState
|
certState atomic.Pointer[CertState]
|
||||||
cipher string
|
cipher string
|
||||||
firewall *Firewall
|
firewall *Firewall
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
@ -141,7 +141,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
outside: c.Outside,
|
outside: c.Outside,
|
||||||
inside: c.Inside,
|
inside: c.Inside,
|
||||||
certState: c.certState,
|
|
||||||
cipher: c.Cipher,
|
cipher: c.Cipher,
|
||||||
firewall: c.Firewall,
|
firewall: c.Firewall,
|
||||||
serveDns: c.ServeDns,
|
serveDns: c.ServeDns,
|
||||||
@ -172,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
l: c.l,
|
l: c.l,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ifce.certState.Store(c.certState)
|
||||||
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
|
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||||
|
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
@ -298,14 +298,15 @@ func (f *Interface) reloadCertKey(c *config.C) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
oldIPs := f.certState.certificate.Details.Ips
|
currentCert := f.certState.Load().certificate
|
||||||
|
oldIPs := currentCert.Details.Ips
|
||||||
newIPs := cs.certificate.Details.Ips
|
newIPs := cs.certificate.Details.Ips
|
||||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
||||||
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.certState = cs
|
f.certState.Store(cs)
|
||||||
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,7 +317,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
|
fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||||
return
|
return
|
||||||
|
|||||||
2
ssh.go
2
ssh.go
@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cert := ifce.certState.certificate
|
cert := ifce.certState.Load().certificate
|
||||||
if len(a) > 0 {
|
if len(a) > 0 {
|
||||||
parsedIp := net.ParseIP(a[0])
|
parsedIp := net.ParseIP(a[0])
|
||||||
if parsedIp == nil {
|
if parsedIp == nil {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user