diff --git a/connection_manager.go b/connection_manager.go index 62a8dd2..81563a4 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { } certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) + return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { @@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) { + if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { return } @@ -474,7 +474,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { Info("Re-handshaking with remote") //TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out - newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo) + newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp) if !newHostinfo.HandshakeReady { ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo) } diff --git a/connection_manager_test.go b/connection_manager_test.go index a489bf2..e220819 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, + myCert: &cert.NebulaCertificate{}, + H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, + myCert: &cert.NebulaCertificate{}, + H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { PublicKey: pubCA, }, } - caCert.Sign(cert.Curve_CURVE25519, privCA) + + assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA)) ncp := &cert.NebulaCAPool{ CAs: cert.NewCAPool().CAs, } @@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { Issuer: "ca", }, } - peerCert.Sign(cert.Curve_CURVE25519, privCA) + assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA)) cs := &CertState{ RawCertificate: []byte{}, @@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { hostinfo := &HostInfo{ vpnIp: vpnIp, ConnectionState: &ConnectionState{ - certState: cs, - peerCert: &peerCert, - H: &noise.HandshakeState{}, + myCert: &cert.NebulaCertificate{}, + peerCert: &peerCert, + H: &noise.HandshakeState{}, }, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) diff --git a/connection_state.go b/connection_state.go index 163e4bc..5260749 100644 --- a/connection_state.go +++ b/connection_state.go @@ -18,7 +18,7 @@ type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState - certState *CertState + myCert *cert.NebulaCertificate peerCert *cert.NebulaCertificate initiator bool messageCounter atomic.Uint64 @@ -28,25 +28,27 @@ type ConnectionState struct { ready bool } -func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { var dhFunc noise.DHFunc - curCertState := f.pki.GetCertState() - switch curCertState.Certificate.Details.Curve { + switch certState.Certificate.Details.Curve { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: dhFunc = noiseutil.DHP256 default: - l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve) + l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) return nil } - cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) - if f.cipher == "chachapoly" { + + var cs noise.CipherSuite + if cipher == "chachapoly" { cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + } else { + cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey} + static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} b := NewBits(ReplayWindow) // Clear out bit 0, we never transmit it and we don't want it showing as packet loss @@ -72,7 +74,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern initiator: initiator, window: b, ready: false, - certState: curCertState, + myCert: certState.Certificate, } return ci diff --git a/control_tester.go b/control_tester.go index a26c8bb..680cd5a 100644 --- a/control_tester.go +++ b/control_tester.go @@ -165,7 +165,7 @@ func (c *Control) GetCert() *cert.NebulaCertificate { } func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { - hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo) + hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp) ixHandshakeStage0(c.f, vpnIp, hostinfo) // If this is a static host, we don't need to wait for the HostQueryReply diff --git a/handshake_ix.go b/handshake_ix.go index 52efdf5..94f408f 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -28,12 +28,14 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { return } - ci := hostinfo.ConnectionState + certState := f.pki.GetCertState() + ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) + hostinfo.ConnectionState = ci hsProto := &NebulaHandshakeDetails{ InitiatorIndex: hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), - Cert: ci.certState.RawCertificateNoKey, + Cert: certState.RawCertificateNoKey, } hsBytes := []byte{} @@ -69,7 +71,8 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { } func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { - ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) + certState := f.pki.GetCertState() + ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) @@ -155,7 +158,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = ci.certState.RawCertificateNoKey + hs.Details.Cert = certState.RawCertificateNoKey // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) diff --git a/handshake_manager.go b/handshake_manager.go index a70f4db..e15b794 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -297,7 +297,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light } // AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it. -func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo { +func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo { // A write lock is used to avoid having to recheck the map and trading a read lock for a write lock c.Lock() defer c.Unlock() @@ -317,10 +317,6 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H }, } - if init != nil { - init(hostinfo) - } - c.vpnIps[vpnIp] = hostinfo c.metricInitiated.Inc(1) c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 383e900..c6df37d 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -28,17 +28,8 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { now := time.Now() blah.NextOutboundHandshakeTimerTick(now, mw) - var initCalled bool - initFunc := func(*HostInfo) { - initCalled = true - } - - i := blah.AddVpnIp(ip, initFunc) - assert.True(t, initCalled) - - initCalled = false - i2 := blah.AddVpnIp(ip, initFunc) - assert.False(t, initCalled) + i := blah.AddVpnIp(ip) + i2 := blah.AddVpnIp(ip) assert.Same(t, i, i2) i.remotes = NewRemoteList(nil) diff --git a/inside.go b/inside.go index 0fac833..6a0e078 100644 --- a/inside.go +++ b/inside.go @@ -1,7 +1,6 @@ package nebula import ( - "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -124,7 +123,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) if hostinfo == nil { - hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo) + hostinfo = f.handshakeManager.AddVpnIp(vpnIp) } ci := hostinfo.ConnectionState @@ -168,12 +167,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { return hostinfo } -// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that -// will create the initial Noise ConnectionState -func (f *Interface) initHostInfo(hostinfo *HostInfo) { - hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) -} - func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { fp := &firewall.Packet{} err := newPacket(p, false, fp) diff --git a/ssh.go b/ssh.go index 44286c8..c68e082 100644 --- a/ssh.go +++ b/ssh.go @@ -607,7 +607,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo) + hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp) if addr != nil { hostInfo.SetRemote(addr) }