We only need the certificate in ConnectionState (#953)

This commit is contained in:
Nate Brown 2023-08-21 14:11:06 -05:00 committed by GitHub
parent 5a131b2975
commit 7edcf620c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 37 additions and 51 deletions

View File

@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
} }
certState := n.intf.pki.GetCertState() certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
} }
func (n *connectionManager) swapPrimary(current, primary *HostInfo) { func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState() certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) { if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
return return
} }
@ -474,7 +474,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
Info("Re-handshaking with remote") Info("Re-handshaking with remote")
//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out //TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo) newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
if !newHostinfo.HandshakeReady { if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo) ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
} }

View File

@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
remoteIndexId: 9901, remoteIndexId: 9901,
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
certState: cs, myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
remoteIndexId: 9901, remoteIndexId: 9901,
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
certState: cs, myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
PublicKey: pubCA, PublicKey: pubCA,
}, },
} }
caCert.Sign(cert.Curve_CURVE25519, privCA)
assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
ncp := &cert.NebulaCAPool{ ncp := &cert.NebulaCAPool{
CAs: cert.NewCAPool().CAs, CAs: cert.NewCAPool().CAs,
} }
@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
Issuer: "ca", Issuer: "ca",
}, },
} }
peerCert.Sign(cert.Curve_CURVE25519, privCA) assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, RawCertificate: []byte{},
@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnIp: vpnIp,
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
certState: cs, myCert: &cert.NebulaCertificate{},
peerCert: &peerCert, peerCert: &peerCert,
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
}, },
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

View File

@ -18,7 +18,7 @@ type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState
dKey *NebulaCipherState dKey *NebulaCipherState
H *noise.HandshakeState H *noise.HandshakeState
certState *CertState myCert *cert.NebulaCertificate
peerCert *cert.NebulaCertificate peerCert *cert.NebulaCertificate
initiator bool initiator bool
messageCounter atomic.Uint64 messageCounter atomic.Uint64
@ -28,25 +28,27 @@ type ConnectionState struct {
ready bool ready bool
} }
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc var dhFunc noise.DHFunc
curCertState := f.pki.GetCertState()
switch curCertState.Certificate.Details.Curve { switch certState.Certificate.Details.Curve {
case cert.Curve_CURVE25519: case cert.Curve_CURVE25519:
dhFunc = noise.DH25519 dhFunc = noise.DH25519
case cert.Curve_P256: case cert.Curve_P256:
dhFunc = noiseutil.DHP256 dhFunc = noiseutil.DHP256
default: default:
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve) l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
return nil return nil
} }
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" { var cs noise.CipherSuite
if cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
} }
static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey} static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
b := NewBits(ReplayWindow) b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss // Clear out bit 0, we never transmit it and we don't want it showing as packet loss
@ -72,7 +74,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
initiator: initiator, initiator: initiator,
window: b, window: b,
ready: false, ready: false,
certState: curCertState, myCert: certState.Certificate,
} }
return ci return ci

View File

@ -165,7 +165,7 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
} }
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo) hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
ixHandshakeStage0(c.f, vpnIp, hostinfo) ixHandshakeStage0(c.f, vpnIp, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply // If this is a static host, we don't need to wait for the HostQueryReply

View File

@ -28,12 +28,14 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
return return
} }
ci := hostinfo.ConnectionState certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
hostinfo.ConnectionState = ci
hsProto := &NebulaHandshakeDetails{ hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hostinfo.localIndexId, InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()), Time: uint64(time.Now().UnixNano()),
Cert: ci.certState.RawCertificateNoKey, Cert: certState.RawCertificateNoKey,
} }
hsBytes := []byte{} hsBytes := []byte{}
@ -69,7 +71,8 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
} }
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1) ci.window.Update(f.l, 1)
@ -155,7 +158,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
Info("Handshake message received") Info("Handshake message received")
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.RawCertificateNoKey hs.Details.Cert = certState.RawCertificateNoKey
// Update the time in case their clock is way off from ours // Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano()) hs.Details.Time = uint64(time.Now().UnixNano())

View File

@ -297,7 +297,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
} }
// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it. // AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo { func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock // A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -317,10 +317,6 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
}, },
} }
if init != nil {
init(hostinfo)
}
c.vpnIps[vpnIp] = hostinfo c.vpnIps[vpnIp] = hostinfo
c.metricInitiated.Inc(1) c.metricInitiated.Inc(1)
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)

View File

@ -28,17 +28,8 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
var initCalled bool i := blah.AddVpnIp(ip)
initFunc := func(*HostInfo) { i2 := blah.AddVpnIp(ip)
initCalled = true
}
i := blah.AddVpnIp(ip, initFunc)
assert.True(t, initCalled)
initCalled = false
i2 := blah.AddVpnIp(ip, initFunc)
assert.False(t, initCalled)
assert.Same(t, i, i2) assert.Same(t, i, i2)
i.remotes = NewRemoteList(nil) i.remotes = NewRemoteList(nil)

View File

@ -1,7 +1,6 @@
package nebula package nebula
import ( import (
"github.com/flynn/noise"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@ -124,7 +123,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
if hostinfo == nil { if hostinfo == nil {
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo) hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
} }
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
@ -168,12 +167,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
return hostinfo return hostinfo
} }
// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
// will create the initial Noise ConnectionState
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
fp := &firewall.Packet{} fp := &firewall.Packet{}
err := newPacket(p, false, fp) err := newPacket(p, false, fp)

2
ssh.go
View File

@ -607,7 +607,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
} }
} }
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo) hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
if addr != nil { if addr != nil {
hostInfo.SetRemote(addr) hostInfo.SetRemote(addr)
} }