mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-08 23:33:58 +01:00
We only need the certificate in ConnectionState (#953)
This commit is contained in:
parent
5a131b2975
commit
7edcf620c0
@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
certState := n.intf.pki.GetCertState()
|
certState := n.intf.pki.GetCertState()
|
||||||
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
|
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||||
@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||||||
|
|
||||||
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
certState := n.intf.pki.GetCertState()
|
certState := n.intf.pki.GetCertState()
|
||||||
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
|
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -474,7 +474,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||||||
Info("Re-handshaking with remote")
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
|
//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
|
||||||
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
|
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
|
||||||
if !newHostinfo.HandshakeReady {
|
if !newHostinfo.HandshakeReady {
|
||||||
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
|
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
remoteIndexId: 9901,
|
remoteIndexId: 9901,
|
||||||
}
|
}
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
myCert: &cert.NebulaCertificate{},
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
remoteIndexId: 9901,
|
remoteIndexId: 9901,
|
||||||
}
|
}
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
certState: cs,
|
myCert: &cert.NebulaCertificate{},
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
PublicKey: pubCA,
|
PublicKey: pubCA,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
caCert.Sign(cert.Curve_CURVE25519, privCA)
|
|
||||||
|
assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
|
||||||
ncp := &cert.NebulaCAPool{
|
ncp := &cert.NebulaCAPool{
|
||||||
CAs: cert.NewCAPool().CAs,
|
CAs: cert.NewCAPool().CAs,
|
||||||
}
|
}
|
||||||
@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
Issuer: "ca",
|
Issuer: "ca",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
peerCert.Sign(cert.Curve_CURVE25519, privCA)
|
assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
RawCertificate: []byte{},
|
RawCertificate: []byte{},
|
||||||
@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
hostinfo := &HostInfo{
|
hostinfo := &HostInfo{
|
||||||
vpnIp: vpnIp,
|
vpnIp: vpnIp,
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
certState: cs,
|
myCert: &cert.NebulaCertificate{},
|
||||||
peerCert: &peerCert,
|
peerCert: &peerCert,
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ type ConnectionState struct {
|
|||||||
eKey *NebulaCipherState
|
eKey *NebulaCipherState
|
||||||
dKey *NebulaCipherState
|
dKey *NebulaCipherState
|
||||||
H *noise.HandshakeState
|
H *noise.HandshakeState
|
||||||
certState *CertState
|
myCert *cert.NebulaCertificate
|
||||||
peerCert *cert.NebulaCertificate
|
peerCert *cert.NebulaCertificate
|
||||||
initiator bool
|
initiator bool
|
||||||
messageCounter atomic.Uint64
|
messageCounter atomic.Uint64
|
||||||
@ -28,25 +28,27 @@ type ConnectionState struct {
|
|||||||
ready bool
|
ready bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||||
var dhFunc noise.DHFunc
|
var dhFunc noise.DHFunc
|
||||||
curCertState := f.pki.GetCertState()
|
|
||||||
|
|
||||||
switch curCertState.Certificate.Details.Curve {
|
switch certState.Certificate.Details.Curve {
|
||||||
case cert.Curve_CURVE25519:
|
case cert.Curve_CURVE25519:
|
||||||
dhFunc = noise.DH25519
|
dhFunc = noise.DH25519
|
||||||
case cert.Curve_P256:
|
case cert.Curve_P256:
|
||||||
dhFunc = noiseutil.DHP256
|
dhFunc = noiseutil.DHP256
|
||||||
default:
|
default:
|
||||||
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
|
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
|
||||||
if f.cipher == "chachapoly" {
|
var cs noise.CipherSuite
|
||||||
|
if cipher == "chachapoly" {
|
||||||
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
|
} else {
|
||||||
|
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||||
}
|
}
|
||||||
|
|
||||||
static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
|
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
|
||||||
|
|
||||||
b := NewBits(ReplayWindow)
|
b := NewBits(ReplayWindow)
|
||||||
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
||||||
@ -72,7 +74,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
|||||||
initiator: initiator,
|
initiator: initiator,
|
||||||
window: b,
|
window: b,
|
||||||
ready: false,
|
ready: false,
|
||||||
certState: curCertState,
|
myCert: certState.Certificate,
|
||||||
}
|
}
|
||||||
|
|
||||||
return ci
|
return ci
|
||||||
|
|||||||
@ -165,7 +165,7 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
|
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
|
||||||
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
|
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
|
||||||
ixHandshakeStage0(c.f, vpnIp, hostinfo)
|
ixHandshakeStage0(c.f, vpnIp, hostinfo)
|
||||||
|
|
||||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||||
|
|||||||
@ -28,12 +28,14 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ci := hostinfo.ConnectionState
|
certState := f.pki.GetCertState()
|
||||||
|
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
|
||||||
|
hostinfo.ConnectionState = ci
|
||||||
|
|
||||||
hsProto := &NebulaHandshakeDetails{
|
hsProto := &NebulaHandshakeDetails{
|
||||||
InitiatorIndex: hostinfo.localIndexId,
|
InitiatorIndex: hostinfo.localIndexId,
|
||||||
Time: uint64(time.Now().UnixNano()),
|
Time: uint64(time.Now().UnixNano()),
|
||||||
Cert: ci.certState.RawCertificateNoKey,
|
Cert: certState.RawCertificateNoKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
hsBytes := []byte{}
|
hsBytes := []byte{}
|
||||||
@ -69,7 +71,8 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
|
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
|
||||||
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
certState := f.pki.GetCertState()
|
||||||
|
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
|
||||||
// Mark packet 1 as seen so it doesn't show up as missed
|
// Mark packet 1 as seen so it doesn't show up as missed
|
||||||
ci.window.Update(f.l, 1)
|
ci.window.Update(f.l, 1)
|
||||||
|
|
||||||
@ -155,7 +158,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||||||
Info("Handshake message received")
|
Info("Handshake message received")
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = ci.certState.RawCertificateNoKey
|
hs.Details.Cert = certState.RawCertificateNoKey
|
||||||
// Update the time in case their clock is way off from ours
|
// Update the time in case their clock is way off from ours
|
||||||
hs.Details.Time = uint64(time.Now().UnixNano())
|
hs.Details.Time = uint64(time.Now().UnixNano())
|
||||||
|
|
||||||
|
|||||||
@ -297,7 +297,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
|
// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
|
||||||
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
|
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
||||||
// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
|
// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
@ -317,10 +317,6 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if init != nil {
|
|
||||||
init(hostinfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.vpnIps[vpnIp] = hostinfo
|
c.vpnIps[vpnIp] = hostinfo
|
||||||
c.metricInitiated.Inc(1)
|
c.metricInitiated.Inc(1)
|
||||||
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
|
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
|
||||||
|
|||||||
@ -28,17 +28,8 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
|
||||||
var initCalled bool
|
i := blah.AddVpnIp(ip)
|
||||||
initFunc := func(*HostInfo) {
|
i2 := blah.AddVpnIp(ip)
|
||||||
initCalled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
i := blah.AddVpnIp(ip, initFunc)
|
|
||||||
assert.True(t, initCalled)
|
|
||||||
|
|
||||||
initCalled = false
|
|
||||||
i2 := blah.AddVpnIp(ip, initFunc)
|
|
||||||
assert.False(t, initCalled)
|
|
||||||
assert.Same(t, i, i2)
|
assert.Same(t, i, i2)
|
||||||
|
|
||||||
i.remotes = NewRemoteList(nil)
|
i.remotes = NewRemoteList(nil)
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/flynn/noise"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@ -124,7 +123,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
|||||||
|
|
||||||
hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
|
hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
|
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
|
||||||
}
|
}
|
||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
|
|
||||||
@ -168,12 +167,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
|||||||
return hostinfo
|
return hostinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
|
|
||||||
// will create the initial Noise ConnectionState
|
|
||||||
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
|
|
||||||
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||||
fp := &firewall.Packet{}
|
fp := &firewall.Packet{}
|
||||||
err := newPacket(p, false, fp)
|
err := newPacket(p, false, fp)
|
||||||
|
|||||||
2
ssh.go
2
ssh.go
@ -607,7 +607,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
|
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
|
||||||
if addr != nil {
|
if addr != nil {
|
||||||
hostInfo.SetRemote(addr)
|
hostInfo.SetRemote(addr)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user