diff --git a/cert_test/cert.go b/cert_test/cert.go index 75134316..c3759f12 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -163,3 +163,55 @@ func P256Keypair() ([]byte, []byte) { pubkey := privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } + +// DummyCert is a minimal cert.Certificate implementation for testing error paths. +type DummyCert struct { + Version_ cert.Version + Curve_ cert.Curve + Groups_ []string + IsCA_ bool + Issuer_ string + Name_ string + Networks_ []netip.Prefix + NotAfter_ time.Time + NotBefore_ time.Time + PublicKey_ []byte + Signature_ []byte + UnsafeNetworks_ []netip.Prefix +} + +func (d *DummyCert) Version() cert.Version { return d.Version_ } +func (d *DummyCert) Curve() cert.Curve { return d.Curve_ } +func (d *DummyCert) Groups() []string { return d.Groups_ } +func (d *DummyCert) IsCA() bool { return d.IsCA_ } +func (d *DummyCert) Issuer() string { return d.Issuer_ } +func (d *DummyCert) Name() string { return d.Name_ } +func (d *DummyCert) Networks() []netip.Prefix { return d.Networks_ } +func (d *DummyCert) NotAfter() time.Time { return d.NotAfter_ } +func (d *DummyCert) NotBefore() time.Time { return d.NotBefore_ } +func (d *DummyCert) PublicKey() []byte { return d.PublicKey_ } +func (d *DummyCert) Signature() []byte { return d.Signature_ } +func (d *DummyCert) UnsafeNetworks() []netip.Prefix { return d.UnsafeNetworks_ } +func (d *DummyCert) Fingerprint() (string, error) { return "", nil } +func (d *DummyCert) CheckSignature(key []byte) bool { return false } +func (d *DummyCert) MarshalForHandshakes() ([]byte, error) { return nil, nil } +func (d *DummyCert) MarshalPEM() ([]byte, error) { return nil, nil } +func (d *DummyCert) MarshalJSON() ([]byte, error) { return nil, nil } +func (d *DummyCert) Marshal() ([]byte, error) { return nil, nil } +func (d *DummyCert) String() string { return "dummy" } +func (d *DummyCert) Copy() cert.Certificate { return d } +func (d *DummyCert) VerifyPrivateKey(c cert.Curve, k []byte) error { return nil } +func (d *DummyCert) Expired(time.Time) bool { return false } +func (d *DummyCert) MarshalPublicKeyPEM() []byte { return nil } +func (d *DummyCert) PublicKeyPEM() []byte { return nil } + +// NewTestCAPool creates a CAPool from the given CA certificates, panicking on error. +func NewTestCAPool(cas ...cert.Certificate) *cert.CAPool { + pool := cert.NewCAPool() + for _, ca := range cas { + if err := pool.AddCA(ca); err != nil { + panic(err) + } + } + return pool +} diff --git a/connection_manager_test.go b/connection_manager_test.go index a015fba9..7dc08a45 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay/overlaytest" @@ -47,7 +46,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -80,7 +79,6 @@ func Test_NewConnectionManagerTest(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -130,7 +128,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -163,7 +161,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -215,7 +212,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -249,7 +246,6 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, - H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -340,9 +336,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - privateKey: []byte{}, - v1Cert: &dummyCert{}, - v1HandshakeBytes: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1Credential: nil, } lh := newTestLighthouse() @@ -372,7 +368,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ConnectionState: &ConnectionState{ myCert: &dummyCert{}, peerCert: cachedPeerCert, - H: &noise.HandshakeState{}, }, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) diff --git a/connection_state.go b/connection_state.go index b85aebd4..47e23b5a 100644 --- a/connection_state.go +++ b/connection_state.go @@ -1,15 +1,12 @@ package nebula import ( - "crypto/rand" "encoding/json" - "fmt" "sync" "sync/atomic" - "github.com/flynn/noise" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/handshake" ) const ReplayWindow = 1024 @@ -17,7 +14,6 @@ const ReplayWindow = 1024 type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState - H *noise.HandshakeState myCert cert.Certificate peerCert *cert.CachedCertificate initiator bool @@ -26,55 +22,24 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { - var dhFunc noise.DHFunc - switch crt.Curve() { - case cert.Curve_CURVE25519: - dhFunc = noise.DH25519 - case cert.Curve_P256: - if cs.pkcs11Backed { - dhFunc = noiseutil.DHP256PKCS11 - } else { - dhFunc = noiseutil.DHP256 - } - default: - return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) - } - - var ncs noise.CipherSuite - if cs.cipher == "chachapoly" { - ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) - } else { - ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) - } - - static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} - hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: ncs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - //NOTE: These should come from CertState (pki.go) when we finally implement it - PresharedKey: []byte{}, - PresharedKeyPlacement: 0, - }) - if err != nil { - return nil, fmt.Errorf("NewConnectionState: %s", err) - } - - // The queue and ready params prevent a counter race that would happen when - // sending stored packets and simultaneously accepting new traffic. +// newConnectionStateFromResult builds a fully-populated ConnectionState from a +// completed handshake.Result. It seeds messageCounter and the replay window so +// that the post-handshake message indices already used on the wire don't count +// as missed traffic in the data plane. +func newConnectionStateFromResult(r *handshake.Result) *ConnectionState { ci := &ConnectionState{ - H: hs, - initiator: initiator, + myCert: r.MyCert, + initiator: r.Initiator, + peerCert: r.RemoteCert, + eKey: NewNebulaCipherState(r.EKey), + dKey: NewNebulaCipherState(r.DKey), window: NewBits(ReplayWindow), - myCert: crt, } - // always start the counter from 2, as packet 1 and packet 2 are handshake packets. - ci.messageCounter.Add(2) - - return ci, nil + ci.messageCounter.Add(r.MessageIndex) + for i := uint64(1); i <= r.MessageIndex; i++ { + ci.window.Update(nil, i) + } + return ci } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { diff --git a/connection_state_test.go b/connection_state_test.go new file mode 100644 index 00000000..dea60d39 --- /dev/null +++ b/connection_state_test.go @@ -0,0 +1,114 @@ +package nebula + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// runTestHandshake runs a complete IX handshake between two freshly-built +// peers and returns the initiator and responder Results. Used to produce +// real cipher states for tests that need to exercise post-handshake glue. +func runTestHandshake(t *testing.T) (initR, respR *handshake.Result) { + t.Helper() + + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + makeCreds := func(name string, networks []netip.Prefix) handshake.GetCredentialFunc { + c, _, rawKey, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil, + ) + priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawKey) + require.NoError(t, err) + hsBytes, err := c.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + cred := handshake.NewCredential(c, hsBytes, priv, ncs) + return func(v cert.Version) *handshake.Credential { + if v == cert.Version2 { + return cred + } + return nil + } + } + + verifier := func(c cert.Certificate) (*cert.CachedCertificate, error) { + return caPool.VerifyCertificate(time.Now(), c) + } + + initCreds := makeCreds("initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCreds := makeCreds("responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM, err := handshake.NewMachine( + cert.Version2, initCreds, verifier, + func() (uint32, error) { return 1000, nil }, + true, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + + respM, err := handshake.NewMachine( + cert.Version2, respCreds, verifier, + func() (uint32, error) { return 2000, nil }, + false, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp, respR, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, respR) + + _, initR, err = initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initR) + + return initR, respR +} + +func TestNewConnectionStateFromResult(t *testing.T) { + initR, respR := runTestHandshake(t) + + t.Run("initiator", func(t *testing.T) { + ci := newConnectionStateFromResult(initR) + assert.True(t, ci.initiator) + assert.Equal(t, initR.MyCert, ci.myCert) + assert.Equal(t, initR.RemoteCert, ci.peerCert) + assert.NotNil(t, ci.eKey) + assert.NotNil(t, ci.dKey) + + // IX has 2 handshake messages; the next data-plane send is counter=3. + assert.Equal(t, uint64(2), ci.messageCounter.Load(), + "messageCounter must equal Result.MessageIndex so the next send is N+1") + + // Both handshake counters must be marked seen so they don't appear lost. + // Check returns false if an index has already been recorded. + assert.False(t, ci.window.Check(nil, 1), "counter 1 must already be seen") + assert.False(t, ci.window.Check(nil, 2), "counter 2 must already be seen") + // Counter 3 is the next data-plane message and must NOT be pre-marked. + assert.True(t, ci.window.Check(nil, 3), "counter 3 must not be pre-seeded") + }) + + t.Run("responder", func(t *testing.T) { + ci := newConnectionStateFromResult(respR) + assert.False(t, ci.initiator) + assert.Equal(t, respR.MyCert, ci.myCert) + assert.Equal(t, respR.RemoteCert, ci.peerCert) + assert.NotNil(t, ci.eKey) + assert.NotNil(t, ci.dKey) + assert.Equal(t, uint64(2), ci.messageCounter.Load()) + }) +} diff --git a/e2e/handshake_manager_test.go b/e2e/handshake_manager_test.go index 3fe784c1..1c6ebacc 100644 --- a/e2e/handshake_manager_test.go +++ b/e2e/handshake_manager_test.go @@ -28,6 +28,7 @@ func makeHandshakePacket(from, to netip.AddrPort, subtype header.MessageSubType, } func TestHandshakeRetransmitDuplicate(t *testing.T) { + t.Parallel() // Verify the responder correctly handles receiving the same msg1 multiple times // (retransmission). The duplicate goes through CheckAndComplete -> ErrAlreadySeen // and the cached response is resent. @@ -78,6 +79,7 @@ func TestHandshakeRetransmitDuplicate(t *testing.T) { } func TestHandshakeTruncatedPacketRecovery(t *testing.T) { + t.Parallel() // Verify that a truncated handshake packet is ignored and the real // packet can still complete the handshake. @@ -126,6 +128,7 @@ func TestHandshakeTruncatedPacketRecovery(t *testing.T) { } func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { + t.Parallel() // A msg2 arriving with no matching pending index should be silently dropped // with no response sent and no state changes. @@ -168,6 +171,7 @@ func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { } func TestHandshakeUnknownMessageCounter(t *testing.T) { + t.Parallel() // A handshake packet with an unexpected message counter should be silently // dropped with no side effects and no UDP response. @@ -199,6 +203,7 @@ func TestHandshakeUnknownMessageCounter(t *testing.T) { } func TestHandshakeUnknownSubtype(t *testing.T) { + t.Parallel() // A handshake packet with an unknown subtype should be silently dropped. ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -224,6 +229,7 @@ func TestHandshakeUnknownSubtype(t *testing.T) { } func TestHandshakeLateResponse(t *testing.T) { + t.Parallel() // After a handshake times out, a late response should be silently ignored // with no new tunnels created. @@ -273,6 +279,7 @@ func TestHandshakeLateResponse(t *testing.T) { } func TestHandshakeSelfConnectionRejected(t *testing.T) { + t.Parallel() // Verify that a node rejects a handshake containing its own VPN IP in the // peer cert. We do this by sending the initiator's own msg1 back to itself. @@ -321,6 +328,7 @@ func TestHandshakeSelfConnectionRejected(t *testing.T) { } func TestHandshakeMessageCounter0Dropped(t *testing.T) { + t.Parallel() // MessageCounter=0 is not a valid handshake message and should be dropped. ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -341,6 +349,7 @@ func TestHandshakeMessageCounter0Dropped(t *testing.T) { } func TestHandshakeRemoteAllowList(t *testing.T) { + t.Parallel() // Verify that a handshake from a blocked underlay IP is dropped with no // response and no state changes. Then verify the same packet from an // allowed IP succeeds. @@ -399,6 +408,7 @@ func TestHandshakeRemoteAllowList(t *testing.T) { } func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { + t.Parallel() // When a duplicate msg1 arrives via ErrAlreadySeen, verify the tunnel // remains functional and hostmap index count is stable. @@ -445,6 +455,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { } func TestHandshakeWrongResponderPacketStore(t *testing.T) { + t.Parallel() // Verify that when the wrong host responds, the cached packets are // transferred to the new handshake, the evil tunnel is closed, evil's // address is blocked, and the correct tunnel is eventually established. @@ -508,6 +519,7 @@ func TestHandshakeWrongResponderPacketStore(t *testing.T) { } func TestHandshakeRelayComplete(t *testing.T) { + t.Parallel() // Verify that a relay handshake completes correctly and relay state is // properly maintained on all three nodes. diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 93f200ac..43fa72f2 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -84,6 +84,7 @@ func BenchmarkHotPathRelay(b *testing.B) { } func TestGoodHandshake(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -134,6 +135,7 @@ func TestGoodHandshake(t *testing.T) { } func TestGoodHandshakeNoOverlap(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack! @@ -169,6 +171,7 @@ func TestGoodHandshakeNoOverlap(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) @@ -245,6 +248,7 @@ func TestWrongResponderHandshake(t *testing.T) { } func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) @@ -327,6 +331,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { } func TestStage1Race(t *testing.T) { + t.Parallel() // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel @@ -407,6 +412,7 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -456,6 +462,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -507,6 +514,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -536,6 +544,7 @@ func TestRelays(t *testing.T) { } func TestRelaysDontCareAboutIps(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}}) @@ -565,6 +574,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) { } func TestReestablishRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -696,6 +706,7 @@ func TestReestablishRelays(t *testing.T) { } func TestStage1RaceRelays(t *testing.T) { + t.Parallel() //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) @@ -743,6 +754,7 @@ func TestStage1RaceRelays(t *testing.T) { } func TestStage1RaceRelays2(t *testing.T) { + t.Parallel() //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) @@ -819,6 +831,7 @@ func TestStage1RaceRelays2(t *testing.T) { } func TestRehandshakingRelays(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) @@ -922,6 +935,7 @@ func TestRehandshakingRelays(t *testing.T) { } func TestRehandshakingRelaysPrimary(t *testing.T) { + t.Parallel() // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) @@ -1026,6 +1040,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) @@ -1121,6 +1136,7 @@ func TestRehandshaking(t *testing.T) { } func TestRehandshakingLoser(t *testing.T) { + t.Parallel() // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -1219,6 +1235,7 @@ func TestRehandshakingLoser(t *testing.T) { } func TestRaceRegression(t *testing.T) { + t.Parallel() // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo @@ -1279,6 +1296,7 @@ func TestRaceRegression(t *testing.T) { } func TestV2NonPrimaryWithLighthouse(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) @@ -1319,6 +1337,7 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) { } func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}}) @@ -1359,6 +1378,7 @@ func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { } func TestLighthouseUpdateOnReload(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) // Create the lighthouse @@ -1434,6 +1454,7 @@ func TestLighthouseUpdateOnReload(t *testing.T) { } func TestGoodHandshakeUnsafeDest(t *testing.T) { + t.Parallel() unsafePrefix := "192.168.6.0/24" ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil) diff --git a/e2e/leak_test.go b/e2e/leak_test.go new file mode 100644 index 00000000..ffb024fe --- /dev/null +++ b/e2e/leak_test.go @@ -0,0 +1,51 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "go.uber.org/goleak" +) + +// TestNoGoroutineLeaks brings up two nebula instances, completes a tunnel, +// stops both, and asserts no goroutines leak past the shutdown. goleak's +// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain +// before failing the assertion. +// +// IgnoreCurrent is necessary in the parallelized suite: other tests can +// leave goroutines mid-shutdown when this one runs (Stop is async, the +// wg.Wait() drain is not blocking on test return). We're checking that +// *this* test's setup tears down cleanly, not that the whole suite is +// idle at this moment. Intentionally NOT t.Parallel()'d for the same +// reason — concurrent test goroutines would always show up. +func TestNoGoroutineLeaks(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() + r.RenderFlow() + + // Settle period: Stop() is non-blocking; the wg-driven goroutines need + // a moment to drain. goleak retries internally too, but a short explicit + // settle reduces flakes when the suite is busy. + time.Sleep(50 * time.Millisecond) +} diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index e8e41945..63c655f3 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -19,6 +19,7 @@ import ( ) func TestDropInactiveTunnels(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -63,6 +64,7 @@ func TestDropInactiveTunnels(t *testing.T) { } func TestCertUpgrade(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -157,6 +159,7 @@ func TestCertUpgrade(t *testing.T) { } func TestCertDowngrade(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -255,6 +258,7 @@ func TestCertDowngrade(t *testing.T) { } func TestCertMismatchCorrection(t *testing.T) { + t.Parallel() // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -322,6 +326,7 @@ func TestCertMismatchCorrection(t *testing.T) { } func TestCrossStackRelaysWork(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) @@ -369,6 +374,7 @@ func TestCrossStackRelaysWork(t *testing.T) { } func TestCloseTunnelAuthenticated(t *testing.T) { + t.Parallel() ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) diff --git a/firewall_test.go b/firewall_test.go index cbf090fd..40b57477 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -1033,7 +1033,7 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} - cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil, "aes") require.NoError(t, err) conf := config.NewC(test.NewLogger()) diff --git a/go.mod b/go.mod index 0de2df7d..24d901c5 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.3.1 + go.uber.org/goleak v1.3.0 go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 diff --git a/handshake/credential.go b/handshake/credential.go new file mode 100644 index 00000000..f6cd5f41 --- /dev/null +++ b/handshake/credential.go @@ -0,0 +1,57 @@ +package handshake + +import ( + "crypto/rand" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" +) + +// Credential holds everything needed to participate in a handshake +// at a given cert version. Version and Curve are read from Cert; the public +// half of the static keypair likewise comes from Cert.PublicKey(). +type Credential struct { + Cert cert.Certificate // the certificate + Bytes []byte // pre-marshaled certificate bytes + privateKey []byte // static private key (public half lives in Cert) + cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash) +} + +// NewCredential creates a Credential with all material needed for handshake +// participation. The cipherSuite should be pre-built by the caller with the +// appropriate DH function, cipher, and hash. +func NewCredential( + c cert.Certificate, + hsBytes []byte, + privateKey []byte, + cipherSuite noise.CipherSuite, +) *Credential { + return &Credential{ + Cert: c, + Bytes: hsBytes, + privateKey: privateKey, + cipherSuite: cipherSuite, + } +} + +// buildHandshakeState creates a noise.HandshakeState from this credential. +func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) { + return noise.NewHandshakeState(noise.Config{ + CipherSuite: hc.cipherSuite, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()}, + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, + }) +} + +// GetCredentialFunc returns the handshake credential for the given version, +// or nil if that version is not available. +// +// Implementations must return credentials drawn from a snapshot stable for +// the lifetime of any single Machine. The Machine may call this multiple +// times during a handshake (e.g. when negotiating to the peer's version) +// and assumes the underlying static keypair is consistent across calls. +type GetCredentialFunc func(v cert.Version) *Credential diff --git a/handshake/errors.go b/handshake/errors.go new file mode 100644 index 00000000..bb8a5893 --- /dev/null +++ b/handshake/errors.go @@ -0,0 +1,21 @@ +package handshake + +import "errors" + +var ( + ErrInitiateOnResponder = errors.New("initiate called on responder") + ErrInitiateAlreadyCalled = errors.New("initiate already called") + ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators") + ErrPacketTooShort = errors.New("packet too short") + ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake") + ErrIncompleteHandshake = errors.New("handshake completed without receiving required content") + ErrMachineFailed = errors.New("handshake machine has failed") + ErrUnknownSubtype = errors.New("unknown handshake subtype") + ErrMissingContent = errors.New("expected handshake content but message was empty") + ErrUnexpectedContent = errors.New("received unexpected handshake content") + ErrIndexAllocation = errors.New("failed to allocate local index") + ErrNoCredential = errors.New("no handshake credential available for cert version") + ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key") + ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager") + ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype") +) diff --git a/handshake/handshake.proto b/handshake/handshake.proto new file mode 100644 index 00000000..8eb32aa6 --- /dev/null +++ b/handshake/handshake.proto @@ -0,0 +1,29 @@ +// This file documents the wire format the nebula handshake speaks. It is +// not run through protoc; the encoder/decoder in payload.go is hand-written +// against this shape directly to keep the parser narrow and panic-free. +// +// Any change to the wire format must be reflected here, and adding a new +// field requires updating MarshalPayload / unmarshalPayloadDetails together +// with the field-uniqueness and wire-type checks in those functions. + +syntax = "proto3"; +package nebula.handshake; + +message NebulaHandshake { + NebulaHandshakeDetails Details = 1; + bytes Hmac = 2; +} + +message NebulaHandshakeDetails { + bytes Cert = 1; + uint32 InitiatorIndex = 2; + uint32 ResponderIndex = 3; + // Cookie was reserved for an anti-DoS mechanism that was never + // implemented. No released version of nebula has ever populated it; the + // hand-written parser silently skips it on read. + uint64 Cookie = 4 [deprecated = true]; + uint64 Time = 5; + uint32 CertVersion = 8; + // reserved for WIP multiport + reserved 6, 7; +} diff --git a/handshake/helpers_test.go b/handshake/helpers_test.go new file mode 100644 index 00000000..c72346cb --- /dev/null +++ b/handshake/helpers_test.go @@ -0,0 +1,116 @@ +package handshake + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/require" +) + +// testCertState holds cert material for a test peer. +type testCertState struct { + version cert.Version + creds map[cert.Version]*Credential +} + +func (s *testCertState) getCredential(v cert.Version) *Credential { + return s.creds[v] +} + +func newTestCertState( + t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix, +) *testCertState { + return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly) +} + +func newTestCertStateWithCipher( + t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix, + cipher noise.CipherFunc, +) *testCertState { + t.Helper() + c, _, rawPrivKey, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil, + ) + + priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey) + require.NoError(t, err) + + hsBytes, err := c.MarshalForHandshakes() + require.NoError(t, err) + + ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256) + return &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(c, hsBytes, priv, ncs), + }, + } +} + +func testVerifier(pool *cert.CAPool) CertVerifier { + return func(c cert.Certificate) (*cert.CachedCertificate, error) { + return pool.VerifyCertificate(time.Now(), c) + } +} + +func newTestMachine( + t *testing.T, + cs *testCertState, + verifier CertVerifier, + initiator bool, + localIndex uint32, +) *Machine { + t.Helper() + m, err := NewMachine( + cs.version, cs.getCredential, + verifier, func() (uint32, error) { return localIndex, nil }, + initiator, header.HandshakeIXPSK0, + ) + require.NoError(t, err) + return m +} + +func initiateHandshake( + t *testing.T, + initCS *testCertState, initVerifier CertVerifier, + respCS *testCertState, respVerifier CertVerifier, +) (initM, respM *Machine, respResult *Result, resp []byte, err error) { + t.Helper() + initM = newTestMachine(t, initCS, initVerifier, true, 100) + msg1, merr := initM.Initiate(nil) + require.NoError(t, merr) + + respM = newTestMachine(t, respCS, respVerifier, false, 200) + resp, respResult, err = respM.ProcessPacket(nil, msg1) + return +} + +func doFullHandshake( + t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool, +) (initResult, respResult *Result) { + t.Helper() + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 1000) + respM := newTestMachine(t, respCS, v, false, 2000) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp, respResult, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, respResult) + require.NotEmpty(t, resp) + + _, initResult, err = initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initResult) + + return initResult, respResult +} diff --git a/handshake/machine.go b/handshake/machine.go new file mode 100644 index 00000000..25ed3a5a --- /dev/null +++ b/handshake/machine.go @@ -0,0 +1,444 @@ +package handshake + +import ( + "bytes" + "fmt" + "slices" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/header" +) + +// IndexAllocator is called by the Machine to allocate a local index for the +// handshake. It is called at most once, when the first outgoing message that +// carries a payload is built. +// +// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning +// "no index assigned" on the wire and in the payload-presence checks. If an +// allocator ever returned 0, a legitimate handshake's payload could be +// indistinguishable from an empty one and would be rejected. +type IndexAllocator func() (uint32, error) + +// CertVerifier is called by the Machine after reconstructing the peer's +// certificate from the handshake. The verifier performs all validation +// (CA trust, expiry, policy checks, allow lists). +type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error) + +// Result contains the results of a successful handshake. +// Returned by ProcessPacket when the handshake is complete. +type Result struct { + EKey *noise.CipherState + DKey *noise.CipherState + MyCert cert.Certificate + RemoteCert *cert.CachedCertificate + RemoteIndex uint32 + LocalIndex uint32 + HandshakeTime uint64 + MessageIndex uint64 // number of messages exchanged during the handshake + Initiator bool +} + +// Machine drives a Noise handshake through N messages. It handles Noise +// protocol operations, certificate reconstruction, and payload encoding. +// Certificate validation is delegated to the caller via CertVerifier. +// +// A Machine is not safe for concurrent use. The caller must ensure that +// Initiate and ProcessPacket are not called concurrently. +// +// Error contract: when ProcessPacket or Initiate returns an error, callers +// must check Failed() to decide what to do next. If Failed() is false the +// underlying noise state was not advanced (the packet was rejected before +// ReadMessage took effect, or the rejection is non-fatal like a stale +// retransmit) and the Machine can accept another packet. If Failed() is +// true the Machine is unrecoverable and the caller must abandon it. +type Machine struct { + hs *noise.HandshakeState + getCred GetCredentialFunc + allocIndex IndexAllocator + verifier CertVerifier + result *Result + msgs []msgFlags + myVersion cert.Version + subtype header.MessageSubType + indexAllocated bool + remoteCertSet bool + payloadSet bool + failed bool +} + +// NewMachine creates a handshake state machine. The subtype determines both +// the noise pattern and the per-message content layout. The credential for +// `version` is fetched via getCred and used to seed the noise.HandshakeState. +// IndexAllocator is called lazily when the first outgoing payload is built. +func NewMachine( + version cert.Version, + getCred GetCredentialFunc, + verifier CertVerifier, + allocIndex IndexAllocator, + initiator bool, + subtype header.MessageSubType, +) (*Machine, error) { + info, err := subtypeInfoFor(subtype) + if err != nil { + return nil, err + } + + cred := getCred(version) + if cred == nil { + return nil, fmt.Errorf("%w: %v", ErrNoCredential, version) + } + + hs, err := cred.buildHandshakeState(initiator, info.pattern) + if err != nil { + return nil, fmt.Errorf("build noise state: %w", err) + } + + return &Machine{ + hs: hs, + subtype: subtype, + msgs: info.msgs, + getCred: getCred, + allocIndex: allocIndex, + verifier: verifier, + myVersion: version, + result: &Result{ + Initiator: initiator, + }, + }, nil +} + +// Failed returns true if the Machine is in an unrecoverable state. +func (m *Machine) Failed() bool { + return m.failed +} + +// Subtype returns the handshake subtype this Machine was built for. +func (m *Machine) Subtype() header.MessageSubType { + return m.subtype +} + +// MessageIndex returns the noise handshake message index, which equals the +// wire counter of the most recently sent or received message. +func (m *Machine) MessageIndex() int { + return m.hs.MessageIndex() +} + +// requireComplete checks that both a peer cert and payload have been received. +// Marks the machine as failed if not. +func (m *Machine) requireComplete() error { + if !m.payloadSet || !m.remoteCertSet { + m.failed = true + return ErrIncompleteHandshake + } + return nil +} + +// myMsgFlags returns the flags for the current outgoing message. +func (m *Machine) myMsgFlags() msgFlags { + idx := m.hs.MessageIndex() + if idx < len(m.msgs) { + return m.msgs[idx] + } + return msgFlags{} +} + +// peerMsgFlags returns the flags for the message we just read. +func (m *Machine) peerMsgFlags() msgFlags { + idx := m.hs.MessageIndex() - 1 + if idx >= 0 && idx < len(m.msgs) { + return m.msgs[idx] + } + return msgFlags{} +} + +// Initiate produces the first handshake message. Only valid for initiators, +// and must be called exactly once before ProcessPacket. +// +// out is a destination buffer the message is appended to and returned. Pass +// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g. +// buf[:0]) with sufficient capacity to avoid allocation. +// +// An error return may not indicate a fatal condition, check Failed() to +// determine if the Machine can still be used. +func (m *Machine) Initiate(out []byte) ([]byte, error) { + if m.failed { + return nil, ErrMachineFailed + } + if !m.result.Initiator { + m.failed = true + return nil, ErrInitiateOnResponder + } + if m.hs.MessageIndex() != 0 { + m.failed = true + return nil, ErrInitiateAlreadyCalled + } + + // At MessageIndex=0 with RemoteIndex still zero, buildResponse produces + // header counter 1 and remote index 0, which is what the initial message needs. + out, _, _, err := m.buildResponse(out) + if err != nil { + m.failed = true + return nil, err + } + return out, nil +} + +// ProcessPacket handles an incoming handshake message. It advances the Noise +// state, validates the peer certificate via the verifier, and optionally +// produces a response. +// +// out is a destination buffer the response is appended to and returned. Pass +// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g. +// buf[:0]) with sufficient capacity to avoid allocation. The returned slice +// is nil when no outgoing message is produced (handshake complete on this +// side, or final message of a multi-message pattern). +// +// Returns a non-nil Result when the handshake is complete. +// An error return may not indicate a fatal condition, check Failed() to +// determine if the Machine can still be used. +func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) { + if m.failed { + return nil, nil, ErrMachineFailed + } + if len(packet) < header.Len { + return nil, nil, ErrPacketTooShort + } + // Reject packets whose subtype doesn't match the one this Machine was + // built for. A pending handshake that suddenly receives a different + // subtype on its index is either a stray packet that matched by chance + // or a peer protocol violation; drop it without failing the Machine so + // the legitimate retransmit can still complete. + if header.MessageSubType(packet[1]) != m.subtype { + return nil, nil, ErrSubtypeMismatch + } + if m.result.Initiator && m.hs.MessageIndex() == 0 { + m.failed = true + return nil, nil, ErrInitiateNotCalled + } + + // The (eKey, dKey) ordering here is correct for IX, where the initiator + // completes the handshake by reading the responder's stage-2 message. + // noise returns (cs1, cs2) where cs1 is the initiator->responder cipher. + // For 3-message patterns where a responder finishes by reading the final + // message, this ordering would be wrong; revisit when XX/pqIX lands. + msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:]) + if err != nil { + // Noise ReadMessage failed. The noise library checkpoints and rolls back + // on failure, so the Machine is still alive. The caller can retry with + // a different packet. + return nil, nil, fmt.Errorf("noise ReadMessage: %w", err) + } + + // From here on, noise state has advanced. Any error is fatal. + flags := m.peerMsgFlags() + + if err := m.processPayload(msg, flags); err != nil { + return nil, nil, err + } + + // If ReadMessage derived keys, the handshake is complete. Noise should + // always produce both keys together; asymmetry is a protocol invariant + // violation. + if eKey != nil || dKey != nil { + if eKey == nil || dKey == nil { + m.failed = true + return nil, nil, ErrAsymmetricCipherKeys + } + if err := m.requireComplete(); err != nil { + return nil, nil, err + } + return nil, m.completed(eKey, dKey), nil + } + + // ReadMessage didn't complete, produce the next outgoing message + out, dk, ek, err := m.buildResponse(out) + if err != nil { + m.failed = true + return nil, nil, err + } + + if ek != nil || dk != nil { + if ek == nil || dk == nil { + m.failed = true + return nil, nil, ErrAsymmetricCipherKeys + } + if err := m.requireComplete(); err != nil { + return nil, nil, err + } + return out, m.completed(ek, dk), nil + } + + return out, nil, nil +} + +func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result { + m.result.EKey = eKey + m.result.DKey = dKey + m.result.MessageIndex = uint64(m.hs.MessageIndex()) + return m.result +} + +func (m *Machine) processPayload(msg []byte, flags msgFlags) error { + if len(msg) == 0 { + if flags.expectsPayload || flags.expectsCert { + m.failed = true + return ErrMissingContent + } + return nil + } + + payload, err := UnmarshalPayload(msg) + if err != nil { + m.failed = true + return fmt.Errorf("unmarshal handshake: %w", err) + } + + // Assert the payload contains exactly what we expect + hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 + if hasPayloadData != flags.expectsPayload { + m.failed = true + return ErrUnexpectedContent + } + + hasCertData := len(payload.Cert) > 0 + if hasCertData != flags.expectsCert { + m.failed = true + return ErrUnexpectedContent + } + + // Process payload + if flags.expectsPayload { + if m.result.Initiator { + m.result.RemoteIndex = payload.ResponderIndex + } else { + m.result.RemoteIndex = payload.InitiatorIndex + } + m.result.HandshakeTime = payload.Time + m.payloadSet = true + } + + // Process certificate + if flags.expectsCert { + if err := m.validateCert(payload); err != nil { + return err + } + } + + return nil +} + +func (m *Machine) validateCert(payload Payload) error { + cred := m.getCred(m.myVersion) + if cred == nil { + m.failed = true + return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion) + } + rc, err := cert.Recombine( + cert.Version(payload.CertVersion), + payload.Cert, + m.hs.PeerStatic(), + cred.Cert.Curve(), + ) + if err != nil { + m.failed = true + return fmt.Errorf("recombine cert: %w", err) + } + + if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) { + m.failed = true + return ErrPublicKeyMismatch + } + + // Version negotiation, if the peer sent a different version and we have it, switch + if rc.Version() != m.myVersion { + if m.getCred(rc.Version()) != nil { + m.myVersion = rc.Version() + } + } + + verified, err := m.verifier(rc) + if err != nil { + m.failed = true + return fmt.Errorf("verify cert: %w", err) + } + + m.result.RemoteCert = verified + m.remoteCertSet = true + return nil +} + +func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) { + if !flags.expectsPayload && !flags.expectsCert { + return nil, nil + } + + var p Payload + if flags.expectsPayload { + if !m.indexAllocated { + index, err := m.allocIndex() + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err) + } + m.result.LocalIndex = index + m.indexAllocated = true + } + + if m.result.Initiator { + p.InitiatorIndex = m.result.LocalIndex + } else { + p.ResponderIndex = m.result.LocalIndex + p.InitiatorIndex = m.result.RemoteIndex + } + p.Time = uint64(time.Now().UnixNano()) + } + if flags.expectsCert { + cred := m.getCred(m.myVersion) + if cred == nil { + return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion) + } + p.Cert = cred.Bytes + p.CertVersion = uint32(cred.Cert.Version()) + m.result.MyCert = cred.Cert + } + + return MarshalPayload(nil, p), nil +} + +func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) { + flags := m.myMsgFlags() + hsBytes, err := m.marshalOutgoing(flags) + if err != nil { + return nil, nil, nil, err + } + + // Extend out by header.Len to make room for the header. slices.Grow is a + // no-op when the cap is already sufficient (the zero-copy case where the + // caller passed a pre-sized buffer). header.Encode overwrites the new + // bytes, so they don't need to be zeroed. + start := len(out) + out = slices.Grow(out, header.Len)[:start+header.Len] + header.Encode( + out[start:], + header.Version, header.Handshake, m.subtype, + m.result.RemoteIndex, + uint64(m.hs.MessageIndex()+1), + ) + + // noise.WriteMessage appends the encrypted handshake message to out, + // reusing capacity when present. + // + // The (dKey, eKey) ordering here is correct for IX, where the responder + // completes the handshake by writing the stage-2 message. noise returns + // (cs1, cs2) where cs1 is the initiator->responder cipher (which is the + // responder's decrypt key). For 3-message patterns where an initiator + // finishes by writing the final message, this ordering would be wrong; + // revisit when XX/pqIX lands. + out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err) + } + + return out, dKey, eKey, nil +} diff --git a/handshake/machine_test.go b/handshake/machine_test.go new file mode 100644 index 00000000..722a39e1 --- /dev/null +++ b/handshake/machine_test.go @@ -0,0 +1,662 @@ +package handshake + +import ( + "net/netip" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" + ct "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/noiseutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMachineIXHappyPath(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name()) + assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name()) + + assert.Equal(t, uint32(1000), initR.LocalIndex) + assert.Equal(t, uint32(2000), initR.RemoteIndex) + assert.Equal(t, uint32(2000), respR.LocalIndex) + assert.Equal(t, uint32(1000), respR.RemoteIndex) + + assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages") + assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages") + + ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello")) + require.NoError(t, err) + pt1, err := respR.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("hello"), pt1) + + ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world")) + require.NoError(t, err) + pt2, err := initR.DKey.Decrypt(nil, nil, ct2) + require.NoError(t, err) + assert.Equal(t, []byte("world"), pt2) +} + +func TestMachineInitiateErrors(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("initiate on responder", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, err := m.Initiate(nil) + require.ErrorIs(t, err, ErrInitiateOnResponder) + assert.True(t, m.Failed()) + }) + + t.Run("initiate called twice", func(t *testing.T) { + m := newTestMachine(t, cs, v, true, 100) + _, err := m.Initiate(nil) + require.NoError(t, err) + _, err = m.Initiate(nil) + require.ErrorIs(t, err, ErrInitiateAlreadyCalled) + assert.True(t, m.Failed()) + }) + + t.Run("process packet before initiate on initiator", func(t *testing.T) { + m := newTestMachine(t, cs, v, true, 100) + _, _, err := m.ProcessPacket(nil, make([]byte, 100)) + require.ErrorIs(t, err, ErrInitiateNotCalled) + assert.True(t, m.Failed()) + }) + + t.Run("calling failed machine", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, err := m.Initiate(nil) // fails: responder + require.Error(t, err) + _, err = m.Initiate(nil) // fails: already failed + require.ErrorIs(t, err, ErrMachineFailed) + }) +} + +func TestMachineProcessPacketErrors(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("packet too short", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + _, _, err := m.ProcessPacket(nil, []byte{1, 2, 3}) + require.ErrorIs(t, err, ErrPacketTooShort) + assert.False(t, m.Failed(), "short packet should not kill machine") + }) + + t.Run("noise decryption failure is recoverable", func(t *testing.T) { + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + initM := newTestMachine(t, initCS, v, true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + respM := newTestMachine(t, cs, v, false, 200) + resp, _, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + + corrupted := make([]byte, len(resp)) + copy(corrupted, resp) + for i := header.Len; i < len(corrupted); i++ { + corrupted[i] ^= 0xff + } + _, _, err = initM.ProcessPacket(nil, corrupted) + require.Error(t, err) + assert.False(t, initM.Failed(), "noise failure should be recoverable") + + // And the machine should still complete a real handshake afterward. + _, result, err := initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, result, "initiator should complete on the legitimate response") + }) + + t.Run("invalid cert is fatal", func(t *testing.T) { + otherCA, _, otherCAKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + respM := newTestMachine(t, cs, v, false, 200) + _, _, err = respM.ProcessPacket(nil, msg1) + require.Error(t, err) + assert.True(t, respM.Failed(), "cert validation failure should kill machine") + }) + + t.Run("subtype mismatch is recoverable", func(t *testing.T) { + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + initM := newTestMachine(t, initCS, v, true, 100) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + // Mutate the subtype byte (offset 1 in the header) to a value the + // responder Machine wasn't built for. + bad := make([]byte, len(msg1)) + copy(bad, msg1) + bad[1] = 0xff + + respM := newTestMachine(t, cs, v, false, 200) + _, _, err = respM.ProcessPacket(nil, bad) + require.ErrorIs(t, err, ErrSubtypeMismatch) + assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine") + + // And the machine should still complete a real handshake afterward. + resp, result, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet") + assert.NotEmpty(t, resp, "responder should produce a stage-2 reply") + }) +} + +// TestMachineProcessPayload exercises processPayload's internal validation +// directly. Most of these failure modes can't be reached black-box once the +// subtype check at the top of ProcessPacket gates external callers, so we +// drive them by hand here for coverage. +func TestMachineProcessPayload(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("empty message with expects fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true}) + require.ErrorIs(t, err, ErrMissingContent) + assert.True(t, m.Failed()) + }) + + t.Run("empty message with no expects passes", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload(nil, msgFlags{}) + require.NoError(t, err) + assert.False(t, m.Failed()) + }) + + t.Run("malformed protobuf is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true}) + require.Error(t, err) + assert.True(t, m.Failed()) + }) + + t.Run("unexpected payload data is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // A payload with index data when none was expected. + bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}) + err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) + + t.Run("unexpected cert data is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // A payload with cert when none was expected. + bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}) + err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) + + t.Run("missing payload data when expected is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + // Cert present, but no index/time fields. + bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}) + err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true}) + require.ErrorIs(t, err, ErrUnexpectedContent) + assert.True(t, m.Failed()) + }) +} + +// TestMachineRequireComplete checks the fail-on-incomplete-handshake path +// directly. Like processPayload above this isn't reachable from a normal IX +// flow, so we drive it by hand. +func TestMachineRequireComplete(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + v := testVerifier(caPool) + + t.Run("missing both fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("payload only fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.payloadSet = true + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("cert only fails", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.remoteCertSet = true + err := m.requireComplete() + require.ErrorIs(t, err, ErrIncompleteHandshake) + assert.True(t, m.Failed()) + }) + + t.Run("both set passes", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + m.payloadSet = true + m.remoteCertSet = true + err := m.requireComplete() + require.NoError(t, err) + assert.False(t, m.Failed()) + }) +} + +func TestMachineAESCipher(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + initCS := newTestCertStateWithCipher( + t, ca, caKey, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + noiseutil.CipherAESGCM, + ) + respCS := newTestCertStateWithCipher( + t, ca, caKey, "resp", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + noiseutil.CipherAESGCM, + ) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works")) + require.NoError(t, err) + pt1, err := respR.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("works"), pt1) + + ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back")) + require.NoError(t, err) + pt2, err := initR.DKey.Decrypt(nil, nil, ct2) + require.NoError(t, err) + assert.Equal(t, []byte("back"), pt2) +} + +func TestResultFields(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initR, respR := doFullHandshake(t, initCS, respCS, caPool) + + assert.True(t, initR.Initiator) + assert.False(t, respR.Initiator) + assert.NotZero(t, initR.HandshakeTime) + assert.NotZero(t, respR.HandshakeTime) + assert.NotNil(t, initR.RemoteCert) + assert.NotNil(t, respR.RemoteCert) +} + +func TestMachineBufferReuse(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 1000) + respM := newTestMachine(t, respCS, v, false, 2000) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + t.Run("response writes into provided buffer", func(t *testing.T) { + buf := make([]byte, 0, 4096) + resp, result, err := respM.ProcessPacket(buf, msg1) + require.NoError(t, err) + require.NotNil(t, result) + + assert.NotEmpty(t, resp, "response should have content") + assert.Equal(t, &buf[:1][0], &resp[:1][0], + "response should reuse the provided buffer's backing array") + }) + + t.Run("initiate writes into provided buffer", func(t *testing.T) { + initM2 := newTestMachine(t, initCS, v, true, 3000) + buf := make([]byte, 0, 4096) + msg, err := initM2.Initiate(buf) + require.NoError(t, err) + + assert.NotEmpty(t, msg, "initiate should have content") + assert.Equal(t, &buf[:1][0], &msg[:1][0], + "initiate should reuse the provided buffer's backing array") + }) + + t.Run("nil out still works", func(t *testing.T) { + initM2 := newTestMachine(t, initCS, v, true, 4000) + respM2 := newTestMachine(t, respCS, v, false, 5000) + + msg1, err := initM2.Initiate(nil) + require.NoError(t, err) + + resp, _, err := respM2.ProcessPacket(nil, msg1) + require.NoError(t, err) + + out, result, err := initM2.ProcessPacket(nil, resp) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Nil(t, out, "initiator should have no response for IX msg2") + }) +} + +func TestMachineMsgIndexTracking(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + v := testVerifier(caPool) + + initM := newTestMachine(t, initCS, v, true, 100) + respM := newTestMachine(t, respCS, v, false, 200) + + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + + resp1, result1, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + assert.NotNil(t, result1) + + _, result2, err := initM.ProcessPacket(nil, resp1) + require.NoError(t, err) + assert.NotNil(t, result2) +} + +func TestMachineThreeMessagePattern(t *testing.T) { + registerTestXXInfo(t) + + // Use HandshakeXX (3 messages) to verify the Machine handles multi-message + // patterns correctly. XX flow: + // msg1 (I->R): [E] - payload only, no cert + // msg2 (R->I): [E, ee, S, es] - payload + cert + // msg3 (I->R): [S, se] - cert only (no payload, not first two) + + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + v := testVerifier(caPool) + + initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}) + respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}) + + initM, err := NewMachine( + cert.Version2, + initCS.getCredential, v, + func() (uint32, error) { return 1000, nil }, + true, header.HandshakeXXPSK0, + ) + require.NoError(t, err) + + respM, err := NewMachine( + cert.Version2, + respCS.getCredential, v, + func() (uint32, error) { return 2000, nil }, + false, header.HandshakeXXPSK0, + ) + require.NoError(t, err) + + // msg1: initiator -> responder (E only, no cert) + msg1, err := initM.Initiate(nil) + require.NoError(t, err) + assert.NotEmpty(t, msg1) + + // Responder processes msg1, should not complete yet, should produce msg2 + msg2, result, err := respM.ProcessPacket(nil, msg1) + require.NoError(t, err) + assert.Nil(t, result, "XX should not complete on msg1") + assert.NotEmpty(t, msg2, "responder should produce msg2") + + // Initiator processes msg2: gets responder's cert, produces msg3, and + // completes (WriteMessage for msg3 derives keys) + msg3, initResult, err := initM.ProcessPacket(nil, msg2) + require.NoError(t, err) + require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3") + assert.NotEmpty(t, msg3, "initiator should produce msg3") + assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name()) + + // Responder processes msg3: gets initiator's cert and completes + _, respResult, err := respM.ProcessPacket(nil, msg3) + require.NoError(t, err) + require.NotNil(t, respResult, "XX responder should complete on msg3") + assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name()) + + assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages") + assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages") + + // Verify keys work + ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages")) + require.NoError(t, err) + pt1, err := respResult.DKey.Decrypt(nil, nil, ct1) + require.NoError(t, err) + assert.Equal(t, []byte("three messages"), pt1) +} + +// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with +// IX since the cert is always in the payload. A 3-message pattern test (HybridIX) +// should exercise the case where cert arrives in msg3 and verify that completing +// without it fails. + +func TestMachineExpiredCert(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, + time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour), + nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + expCert, _, expKeyPEM, _ := ct.NewTestCert( + cert.Version2, cert.Curve_CURVE25519, ca, caKey, + "expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour), + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil, + ) + expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM) + require.NoError(t, err) + expHsBytes, err := expCert.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + + expiredCS := &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs), + }, + } + + respCS := newTestCertState( + t, ca, caKey, "responder", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, expiredCS, testVerifier(caPool), + respCS, testVerifier(caPool), + ) + require.ErrorContains(t, err, "verify cert") + assert.True(t, respM.Failed()) +} + +func TestMachineNoCertNetworks(t *testing.T) { + ca, _, caKey, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca) + + caHsBytes, err := ca.MarshalForHandshakes() + require.NoError(t, err) + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + + noNetCS := &testCertState{ + version: cert.Version2, + creds: map[cert.Version]*Credential{ + cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs), + }, + } + + respCS := newTestCertState( + t, ca, caKey, "responder", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, noNetCS, testVerifier(caPool), + respCS, testVerifier(caPool), + ) + require.Error(t, err) + assert.True(t, respM.Failed()) +} + +func TestMachineDifferentCAs(t *testing.T) { + ca1, _, caKey1, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + ca2, _, caKey2, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + + initCS := newTestCertState( + t, ca1, caKey1, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + respCS := newTestCertState( + t, ca2, caKey2, "resp", + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, + ) + + _, respM, _, _, err := initiateHandshake( + t, initCS, testVerifier(ct.NewTestCAPool(ca1)), + respCS, testVerifier(ct.NewTestCAPool(ca2)), + ) + require.ErrorContains(t, err, "verify cert") + assert.True(t, respM.Failed()) +} + +func TestMachineVersionNegotiation(t *testing.T) { + ca1, _, caKey1, _ := ct.NewTestCaCert( + cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + ca2, _, caKey2, _ := ct.NewTestCaCert( + cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, + ) + caPool := ct.NewTestCAPool(ca1, ca2) + + makeMultiVersionResp := func(t *testing.T) *testCertState { + t.Helper() + respCertV1, _, respKeyPEM, _ := ct.NewTestCert( + cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp", + ca1.NotBefore(), ca1.NotAfter(), + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil, + ) + respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM) + respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2) + respHsV1, _ := respCertV1.MarshalForHandshakes() + respHsV2, _ := respCertV2.MarshalForHandshakes() + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + return &testCertState{ + version: cert.Version1, + creds: map[cert.Version]*Credential{ + cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs), + cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs), + }, + } + } + + t.Run("responder matches initiator version", func(t *testing.T) { + initCS := newTestCertState( + t, ca2, caKey2, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + respCS := makeMultiVersionResp(t) + v := testVerifier(caPool) + + initM, _, respResult, resp, err := initiateHandshake( + t, initCS, v, + respCS, v, + ) + require.NoError(t, err) + require.NotNil(t, respResult) + + assert.Equal(t, cert.Version2, respResult.MyCert.Version(), + "responder should negotiate to initiator's version") + + _, initResult, err := initM.ProcessPacket(nil, resp) + require.NoError(t, err) + require.NotNil(t, initResult) + assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(), + "initiator should see V2 cert from responder") + }) + + t.Run("responder keeps version when no match available", func(t *testing.T) { + initCS := newTestCertState( + t, ca2, caKey2, "init", + []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + ) + + respCert, _, respKeyPEM, _ := ct.NewTestCert( + cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp", + ca1.NotBefore(), ca1.NotAfter(), + []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil, + ) + respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM) + respHs, _ := respCert.MarshalForHandshakes() + ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + respCS := &testCertState{ + version: cert.Version1, + creds: map[cert.Version]*Credential{ + cert.Version1: NewCredential(respCert, respHs, respKey, ncs), + }, + } + + v := testVerifier(caPool) + _, _, respResult, _, err := initiateHandshake( + t, initCS, v, + respCS, v, + ) + require.NoError(t, err) + require.NotNil(t, respResult) + + assert.Equal(t, cert.Version1, respResult.MyCert.Version(), + "responder should keep V1 when V2 not available") + }) +} diff --git a/handshake/patterns.go b/handshake/patterns.go new file mode 100644 index 00000000..a0cc1a70 --- /dev/null +++ b/handshake/patterns.go @@ -0,0 +1,54 @@ +package handshake + +import ( + "fmt" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/header" +) + +// msgFlags tracks what application data a handshake message carries. +type msgFlags struct { + expectsPayload bool // message carries indexes and time + expectsCert bool // message carries the certificate +} + +// subtypeInfo bundles the noise pattern with the per-message flags for a +// given handshake subtype. +type subtypeInfo struct { + pattern noise.HandshakePattern + msgs []msgFlags +} + +// subtypeInfos defines the noise pattern and message content layout for each +// handshake subtype. +var subtypeInfos = map[header.MessageSubType]subtypeInfo{ + // IX: 2 messages, both carry payload and cert + header.HandshakeIXPSK0: { + pattern: noise.HandshakeIX, + msgs: []msgFlags{ + {expectsPayload: true, expectsCert: true}, + {expectsPayload: true, expectsCert: true}, + }, + }, + + // XX: 3 messages + // msg1 (I->R): payload only + // msg2 (R->I): payload + cert + // msg3 (I->R): cert only + //header.HandshakeXXPSK0: { + // pattern: noise.HandshakeXX, + // msgs: []msgFlags{ + // {expectsPayload: true, expectsCert: false}, + // {expectsPayload: true, expectsCert: true}, + // {expectsPayload: false, expectsCert: true}, + // }, + //}, +} + +func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) { + if info, ok := subtypeInfos[subtype]; ok { + return info, nil + } + return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype) +} diff --git a/handshake/patterns_test.go b/handshake/patterns_test.go new file mode 100644 index 00000000..d6207e00 --- /dev/null +++ b/handshake/patterns_test.go @@ -0,0 +1,63 @@ +package handshake + +import ( + "testing" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/header" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSubtypeInfo(t *testing.T) { + t.Run("IX", func(t *testing.T) { + info, err := subtypeInfoFor(header.HandshakeIXPSK0) + require.NoError(t, err) + assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name) + require.Len(t, info.msgs, 2) + // msg1: payload + cert + assert.True(t, info.msgs[0].expectsPayload) + assert.True(t, info.msgs[0].expectsCert) + // msg2: payload + cert + assert.True(t, info.msgs[1].expectsPayload) + assert.True(t, info.msgs[1].expectsCert) + }) + + t.Run("XX", func(t *testing.T) { + registerTestXXInfo(t) + info, err := subtypeInfoFor(header.HandshakeXXPSK0) + require.NoError(t, err) + assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name) + require.Len(t, info.msgs, 3) + // msg1: payload only + assert.True(t, info.msgs[0].expectsPayload) + assert.False(t, info.msgs[0].expectsCert) + // msg2: payload + cert + assert.True(t, info.msgs[1].expectsPayload) + assert.True(t, info.msgs[1].expectsCert) + // msg3: cert only + assert.False(t, info.msgs[2].expectsPayload) + assert.True(t, info.msgs[2].expectsCert) + }) + + t.Run("unknown subtype returns error", func(t *testing.T) { + _, err := subtypeInfoFor(99) + require.ErrorIs(t, err, ErrUnknownSubtype) + }) +} + +// registerTestXXInfo temporarily registers XX subtype info for testing. +func registerTestXXInfo(t *testing.T) { + t.Helper() + subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{ + pattern: noise.HandshakeXX, + msgs: []msgFlags{ + {expectsPayload: true, expectsCert: false}, + {expectsPayload: true, expectsCert: true}, + {expectsPayload: false, expectsCert: true}, + }, + } + t.Cleanup(func() { + delete(subtypeInfos, header.HandshakeXXPSK0) + }) +} diff --git a/handshake/payload.go b/handshake/payload.go new file mode 100644 index 00000000..4567fc0d --- /dev/null +++ b/handshake/payload.go @@ -0,0 +1,173 @@ +package handshake + +import ( + "errors" + "math" + + "google.golang.org/protobuf/encoding/protowire" +) + +var ( + errInvalidHandshakeMessage = errors.New("invalid handshake message") + errInvalidHandshakeDetails = errors.New("invalid handshake details") +) + +// Payload represents the decoded fields of a handshake message. +// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. +type Payload struct { + Cert []byte + InitiatorIndex uint32 + ResponderIndex uint32 + Time uint64 + CertVersion uint32 +} + +// Proto field numbers for NebulaHandshakeDetails +const ( + fieldCert = 1 // bytes + fieldInitiatorIndex = 2 // uint32 + fieldResponderIndex = 3 // uint32 + fieldTime = 5 // uint64 + fieldCertVersion = 8 // uint32 +) + +// MarshalPayload encodes a handshake payload in protobuf wire format compatible +// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. +// Returns out (which may be nil), with the marshalled Payload appended to it. +func MarshalPayload(out []byte, p Payload) []byte { + var details []byte + + if len(p.Cert) > 0 { + details = protowire.AppendTag(details, fieldCert, protowire.BytesType) + details = protowire.AppendBytes(details, p.Cert) + } + if p.InitiatorIndex != 0 { + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.InitiatorIndex)) + } + if p.ResponderIndex != 0 { + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.ResponderIndex)) + } + if p.Time != 0 { + details = protowire.AppendTag(details, fieldTime, protowire.VarintType) + details = protowire.AppendVarint(details, p.Time) + } + if p.CertVersion != 0 { + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.CertVersion)) + } + + out = protowire.AppendTag(out, 1, protowire.BytesType) + out = protowire.AppendBytes(out, details) + + return out +} + +// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message. +func UnmarshalPayload(b []byte) (Payload, error) { + var p Payload + + for len(b) > 0 { + num, typ, n := protowire.ConsumeTag(b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + + switch { + case num == 1 && typ == protowire.BytesType: + details, n := protowire.ConsumeBytes(b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + if err := unmarshalPayloadDetails(&p, details); err != nil { + return p, err + } + default: + n := protowire.ConsumeFieldValue(num, typ, b) + if n < 0 { + return p, errInvalidHandshakeMessage + } + b = b[n:] + } + } + + return p, nil +} + +func unmarshalPayloadDetails(p *Payload, b []byte) error { + for len(b) > 0 { + num, typ, n := protowire.ConsumeTag(b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + + // For known field numbers, reject any non-matching wire type as a + // hard error rather than silently skipping. The caller will catch + // missing-field cases downstream, but a wire-type mismatch on a tag + // we know is a peer protocol violation worth flagging here. + // Repeated occurrences of a singular field follow proto3 last-wins. + switch num { + case fieldCert: + if typ != protowire.BytesType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeBytes(b) + if n < 0 { + return errInvalidHandshakeDetails + } + p.Cert = append([]byte(nil), v...) + b = b[n:] + case fieldInitiatorIndex: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.InitiatorIndex = uint32(v) + b = b[n:] + case fieldResponderIndex: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.ResponderIndex = uint32(v) + b = b[n:] + case fieldTime: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 { + return errInvalidHandshakeDetails + } + p.Time = v + b = b[n:] + case fieldCertVersion: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.CertVersion = uint32(v) + b = b[n:] + default: + n := protowire.ConsumeFieldValue(num, typ, b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + } + } + return nil +} diff --git a/handshake/payload_test.go b/handshake/payload_test.go new file mode 100644 index 00000000..2ff3231c --- /dev/null +++ b/handshake/payload_test.go @@ -0,0 +1,361 @@ +package handshake + +import ( + "bytes" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestPayloadRoundTrip(t *testing.T) { + t.Run("all fields set", func(t *testing.T) { + data := MarshalPayload(nil, Payload{ + Cert: []byte("test-cert-bytes"), + CertVersion: 2, + InitiatorIndex: 12345, + ResponderIndex: 67890, + Time: 1234567890, + }) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, []byte("test-cert-bytes"), got.Cert) + assert.Equal(t, uint32(12345), got.InitiatorIndex) + assert.Equal(t, uint32(67890), got.ResponderIndex) + assert.Equal(t, uint64(1234567890), got.Time) + assert.Equal(t, uint32(2), got.CertVersion) + }) + + t.Run("minimal fields", func(t *testing.T) { + data := MarshalPayload(nil, Payload{InitiatorIndex: 1}) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, uint32(1), got.InitiatorIndex) + assert.Equal(t, uint32(0), got.ResponderIndex) + assert.Equal(t, uint64(0), got.Time) + assert.Nil(t, got.Cert) + }) + + t.Run("empty payload", func(t *testing.T) { + data := MarshalPayload(nil, Payload{}) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, uint32(0), got.InitiatorIndex) + }) + + t.Run("large cert bytes", func(t *testing.T) { + bigCert := make([]byte, 4096) + for i := range bigCert { + bigCert[i] = byte(i % 256) + } + + data := MarshalPayload(nil, Payload{ + Cert: bigCert, + CertVersion: 2, + InitiatorIndex: 999, + }) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + assert.Equal(t, bigCert, got.Cert) + assert.Equal(t, uint32(999), got.InitiatorIndex) + }) + + t.Run("append to existing buffer", func(t *testing.T) { + prefix := []byte("prefix") + data := MarshalPayload(prefix, Payload{InitiatorIndex: 42}) + + assert.Equal(t, []byte("prefix"), data[:6]) + + got, err := UnmarshalPayload(data[6:]) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) +} + +func TestPayloadUnknownFields(t *testing.T) { + t.Run("unknown field in outer message is skipped", func(t *testing.T) { + // Marshal a normal payload then append an unknown field (field 99, varint) + data := MarshalPayload(nil, Payload{InitiatorIndex: 42}) + data = protowire.AppendTag(data, 99, protowire.VarintType) + data = protowire.AppendVarint(data, 12345) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) + + t.Run("unknown field in details is skipped", func(t *testing.T) { + // Build details with a known field + unknown field + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 77) + // Unknown field 50, varint + details = protowire.AppendTag(details, 50, protowire.VarintType) + details = protowire.AppendVarint(details, 9999) + // Another known field after the unknown one + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 88) + + // Wrap in outer message + var data []byte + data = protowire.AppendTag(data, 1, protowire.BytesType) + data = protowire.AppendBytes(data, details) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(77), got.InitiatorIndex) + assert.Equal(t, uint32(88), got.ResponderIndex) + }) + + t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) { + // Fields 6 and 7 are reserved in the proto definition + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 100) + details = protowire.AppendTag(details, 6, protowire.VarintType) + details = protowire.AppendVarint(details, 1) + details = protowire.AppendTag(details, 7, protowire.VarintType) + details = protowire.AppendVarint(details, 2) + + var data []byte + data = protowire.AppendTag(data, 1, protowire.BytesType) + data = protowire.AppendBytes(data, details) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + assert.Equal(t, uint32(100), got.InitiatorIndex) + }) +} + +func TestPayloadBytesConsumed(t *testing.T) { + t.Run("all bytes consumed on valid input", func(t *testing.T) { + original := Payload{ + Cert: []byte("cert"), + CertVersion: 2, + InitiatorIndex: 100, + ResponderIndex: 200, + Time: 999, + } + data := MarshalPayload(nil, original) + + got, err := UnmarshalPayload(data) + require.NoError(t, err) + + // Re-marshal and compare — proves we consumed and reproduced all fields + remarshaled := MarshalPayload(nil, got) + assert.Equal(t, data, remarshaled) + }) +} + +// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope +// so UnmarshalPayload can reach unmarshalPayloadDetails. +func wrapDetails(details []byte) []byte { + var out []byte + out = protowire.AppendTag(out, 1, protowire.BytesType) + out = protowire.AppendBytes(out, details) + return out +} + +func TestPayloadUnmarshalErrors(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + got, err := UnmarshalPayload(nil) + require.NoError(t, err) + assert.Equal(t, uint32(0), got.InitiatorIndex) + }) + + t.Run("truncated outer tag", func(t *testing.T) { + _, err := UnmarshalPayload([]byte{0x80}) + assert.Error(t, err) + }) + + t.Run("truncated outer details field", func(t *testing.T) { + _, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05}) + assert.Error(t, err) + }) + + t.Run("truncated outer unknown field", func(t *testing.T) { + // Valid tag for unknown field 99 varint, but no value follows + var data []byte + data = protowire.AppendTag(data, 99, protowire.VarintType) + _, err := UnmarshalPayload(data) + assert.Error(t, err) + }) + + t.Run("truncated details tag", func(t *testing.T) { + _, err := UnmarshalPayload(wrapDetails([]byte{0x80})) + assert.Error(t, err) + }) + + t.Run("truncated cert bytes", func(t *testing.T) { + // Field 1 (cert), bytes type, length 10 but only 2 bytes + var details []byte + details = protowire.AppendTag(details, fieldCert, protowire.BytesType) + details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated initiator index varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = append(details, 0x80) // incomplete varint + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated responder index varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated time varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldTime, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated cert version varint", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = append(details, 0x80) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("truncated unknown field in details", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, 50, protowire.VarintType) + details = append(details, 0x80) // incomplete varint + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert with wrong wire type rejected", func(t *testing.T) { + // fieldCert as Varint instead of Bytes. + var details []byte + details = protowire.AppendTag(details, fieldCert, protowire.VarintType) + details = protowire.AppendVarint(details, 42) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("initiator index with wrong wire type rejected", func(t *testing.T) { + // fieldInitiatorIndex as Bytes instead of Varint. + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("time with wrong wire type rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldTime, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert version with wrong wire type rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType) + details = protowire.AppendBytes(details, []byte{1, 2, 3}) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) { + // Per proto3, multiple instances of a singular field are accepted and + // the last value wins. We keep this behavior so that peers using + // alternative encoders aren't rejected. + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 1) + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, 42) + got, err := UnmarshalPayload(wrapDetails(details)) + require.NoError(t, err) + assert.Equal(t, uint32(42), got.InitiatorIndex) + }) + + t.Run("initiator index varint overflow rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) + details = protowire.AppendVarint(details, math.MaxUint32+1) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + + t.Run("cert version varint overflow rejected", func(t *testing.T) { + var details []byte + details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) + details = protowire.AppendVarint(details, math.MaxUint32+1) + _, err := UnmarshalPayload(wrapDetails(details)) + assert.Error(t, err) + }) + +} + +// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it +// never panics, and for any input that parses cleanly, that re-marshal + +// re-parse is a fix-point. Inputs come from an authenticated peer (post- +// noise-decrypt), so the threat model is "valid peer behaving arbitrarily," +// not "unauthenticated injection." +func FuzzPayload(f *testing.F) { + // Seed corpus with a handful of known-good shapes. + f.Add(MarshalPayload(nil, Payload{})) + f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})) + f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})) + f.Add(MarshalPayload(nil, Payload{ + Cert: []byte("seed-cert"), + InitiatorIndex: 1, + ResponderIndex: 2, + Time: 3, + CertVersion: 2, + })) + f.Add([]byte{}) + f.Add([]byte{0xff}) + + f.Fuzz(func(t *testing.T, data []byte) { + p1, err := UnmarshalPayload(data) + if err != nil { + return + } + + // For any input that parses, re-marshaling and re-parsing must + // yield an equivalent Payload. This catches dispatch bugs (e.g. + // emitting a field on marshal that we don't accept on parse) and + // any non-idempotent parsing behavior. + b2 := MarshalPayload(nil, p1) + p2, err := UnmarshalPayload(b2) + if err != nil { + t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2) + } + if !payloadsEqual(p1, p2) { + t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2) + } + }) +} + +func payloadsEqual(a, b Payload) bool { + return bytes.Equal(a.Cert, b.Cert) && + a.InitiatorIndex == b.InitiatorIndex && + a.ResponderIndex == b.ResponderIndex && + a.Time == b.Time && + a.CertVersion == b.CertVersion +} diff --git a/handshake_ix.go b/handshake_ix.go deleted file mode 100644 index a086960e..00000000 --- a/handshake_ix.go +++ /dev/null @@ -1,813 +0,0 @@ -package nebula - -import ( - "bytes" - "context" - "log/slog" - "net/netip" - "time" - - "github.com/flynn/noise" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/header" -) - -// NOISE IX Handshakes - -// This function constructs a handshake packet, but does not actually send it -// Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { - err := f.handshakeManager.allocateIndex(hh) - if err != nil { - f.l.Error("Failed to generate index", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - ) - return false - } - - cs := f.pki.getCertState() - v := cs.initiatingVersion - if hh.initiatingVersionOverride != cert.VersionPre1 { - v = hh.initiatingVersionOverride - } else if v < cert.Version2 { - // If we're connecting to a v6 address we should encourage use of a V2 cert - for _, a := range hh.hostinfo.vpnAddrs { - if a.Is6() { - v = cert.Version2 - break - } - } - } - - crt := cs.getCertificate(v) - if crt == nil { - f.l.Error("Unable to handshake with host because no certificate is available", - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", v, - ) - return false - } - - crtHs := cs.getHandshakeBytes(v) - if crtHs == nil { - f.l.Error("Unable to handshake with host because no certificate handshake bytes is available", - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", v, - ) - return false - } - - ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX) - if err != nil { - f.l.Error("Failed to create connection state", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", v, - ) - return false - } - hh.hostinfo.ConnectionState = ci - - hs := &NebulaHandshake{ - Details: &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: crtHs, - CertVersion: uint32(v), - }, - } - - hsBytes, err := hs.Marshal() - if err != nil { - f.l.Error("Failed to marshal handshake message", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "certVersion", v, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - ) - return false - } - - h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - - msg, _, _, err := ci.H.WriteMessage(h, hsBytes) - if err != nil { - f.l.Error("Failed to call noise.WriteMessage", - "error", err, - "vpnAddrs", hh.hostinfo.vpnAddrs, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - ) - return false - } - - // We are sending handshake packet 1, so we don't expect to receive - // handshake packet 1 from the responder - ci.window.Update(f.l, 1) - - hh.hostinfo.HandshakePacket[0] = msg - hh.ready = true - return true -} - -func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) { - cs := f.pki.getCertState() - crt := cs.GetDefaultCertificate() - if crt == nil { - f.l.Error("Unable to handshake with host because no certificate is available", - "from", via, - "handshake", m{"stage": 0, "style": "ix_psk0"}, - "certVersion", cs.initiatingVersion, - ) - return - } - - ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX) - if err != nil { - f.l.Error("Failed to create connection state", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(f.l, 1) - - msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.Error("Failed to call noise.ReadMessage", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - hs := &NebulaHandshake{} - err = hs.Unmarshal(msg) - if err != nil || hs.Details == nil { - f.l.Error("Failed unmarshal handshake message", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) - if err != nil { - f.l.Info("Handshake did not contain a certificate", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) - if err != nil { - fp, fperr := rc.Fingerprint() - if fperr != nil { - fp = "" - } - - attrs := []slog.Attr{ - slog.Any("error", err), - slog.Any("from", via), - slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}), - slog.Any("certVpnNetworks", rc.Networks()), - slog.String("certFingerprint", fp), - } - if f.l.Enabled(context.Background(), slog.LevelDebug) { - attrs = append(attrs, slog.Any("cert", rc)) - } - - // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that - // callers grow conditionally, which has no pair-form equivalent. - //nolint:sloglint - f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) - return - } - - if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.Info("public key mismatch between certificate and handshake", - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - "cert", remoteCert, - ) - return - } - - if remoteCert.Certificate.Version() != ci.myCert.Version() { - // We started off using the wrong certificate version, lets see if we can match the version that was sent to us - myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) - if myCertOtherVersion == nil { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - f.l.Debug("Might be unable to handshake with host due to missing certificate version", - "error", err, - "from", via, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - "cert", remoteCert, - ) - } - } else { - // Record the certificate we are actually using - ci.myCert = myCertOtherVersion - } - } - - if len(remoteCert.Certificate.Networks()) == 0 { - f.l.Info("No networks in certificate", - "error", err, - "from", via, - "cert", remoteCert, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - certName := remoteCert.Certificate.Name() - certVersion := remoteCert.Certificate.Version() - fingerprint := remoteCert.Fingerprint - issuer := remoteCert.Certificate.Issuer() - vpnNetworks := remoteCert.Certificate.Networks() - - anyVpnAddrsInCommon := false - vpnAddrs := make([]netip.Addr, len(vpnNetworks)) - for i, network := range vpnNetworks { - if f.myVpnAddrsTable.Contains(network.Addr()) { - f.l.Error("Refusing to handshake with myself", - "vpnNetworks", vpnNetworks, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - vpnAddrs[i] = network.Addr() - if f.myVpnNetworksTable.Contains(network.Addr()) { - anyVpnAddrsInCommon = true - } - } - - if !via.IsRelayed { - // We only want to apply the remote allow list for direct tunnels here - if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", - "vpnAddrs", vpnAddrs, - "from", via, - ) - } - return - } - } - - myIndex, err := generateIndex(f.l) - if err != nil { - f.l.Error("Failed to generate index", - "error", err, - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - hostinfo := &HostInfo{ - ConnectionState: ci, - localIndexId: myIndex, - remoteIndexId: hs.Details.InitiatorIndex, - vpnAddrs: vpnAddrs, - HandshakePacket: make(map[uint8][]byte, 0), - lastHandshakeTime: hs.Details.Time, - relayState: RelayState{ - relays: nil, - relayForByAddr: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, - }, - } - - msgRxL := f.l.With( - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - - if anyVpnAddrsInCommon { - msgRxL.Info("Handshake message received") - } else { - //todo warn if not lighthouse or relay? - msgRxL.Info("Handshake message received, but no vpnNetworks in common.") - } - - hs.Details.ResponderIndex = myIndex - hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) - if hs.Details.Cert == nil { - msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available", - "myCertVersion", ci.myCert.Version(), - ) - return - } - - hs.Details.CertVersion = uint32(ci.myCert.Version()) - // Update the time in case their clock is way off from ours - hs.Details.Time = uint64(time.Now().UnixNano()) - - hsBytes, err := hs.Marshal() - if err != nil { - f.l.Error("Failed to marshal handshake message", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) - msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) - if err != nil { - f.l.Error("Failed to call noise.WriteMessage", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } else if dKey == nil || eKey == nil { - f.l.Error("Noise did not arrive at a key", - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - - hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:])) - copy(hostinfo.HandshakePacket[0], packet[header.Len:]) - - // Regardless of whether you are the sender or receiver, you should arrive here - // and complete standing up the connection. - hostinfo.HandshakePacket[2] = make([]byte, len(msg)) - copy(hostinfo.HandshakePacket[2], msg) - - // We are sending handshake packet 2, so we don't expect to receive - // handshake packet 2 from the initiator. - ci.window.Update(f.l, 2) - - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - - hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) - if !via.IsRelayed { - hostinfo.SetRemote(via.UdpAddr) - } - hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) - - existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) - if err != nil { - switch err { - case ErrAlreadySeen: - // Update remote if preferred - if existing.SetRemoteIfPreferred(f.hostMap, via) { - // Send a test packet to ensure the other side has also switched to - // the preferred remote - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - } - - msg = existing.HandshakePacket[2] - f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if !via.IsRelayed { - err := f.outside.WriteTo(msg, via.UdpAddr) - if err != nil { - f.l.Error("Failed to send handshake message", - "vpnAddrs", existing.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cached", true, - "error", err, - ) - } else { - f.l.Info("Handshake message sent", - "vpnAddrs", existing.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cached", true, - ) - } - return - } else { - if via.relay == nil { - f.l.Error("Handshake send failed: both addr and via.relay are nil.") - return - } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.Info("Handshake message sent", - "vpnAddrs", existing.vpnAddrs, - "relay", via.relayHI.vpnAddrs[0], - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cached", true, - ) - return - } - case ErrExistingHostInfo: - // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.Info("Handshake too old", - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "oldHandshakeTime", existing.lastHandshakeTime, - "newHandshakeTime", hostinfo.lastHandshakeTime, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - - // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - return - case ErrLocalIndexCollision: - // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.Error("Failed to add HostInfo due to localIndex collision", - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - "localIndex", hostinfo.localIndexId, - "collision", existing.vpnAddrs, - ) - return - default: - // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete - // And we forget to update it here - f.l.Error("Failed to add HostInfo to HostMap", - "error", err, - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 1, "style": "ix_psk0"}, - ) - return - } - } - - // Do the send - f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if !via.IsRelayed { - err = f.outside.WriteTo(msg, via.UdpAddr) - log := f.l.With( - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - if err != nil { - log.Error("Failed to send handshake", "error", err) - } else { - log.Info("Handshake message sent") - } - } else { - if via.relay == nil { - f.l.Error("Handshake send failed: both addr and via.relay are nil.") - return - } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - // I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure - // it's correctly marked as working. - via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) - f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.Info("Handshake message sent", - "vpnAddrs", vpnAddrs, - "relay", via.relayHI.vpnAddrs[0], - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - } - - f.connectionManager.AddTrafficWatch(hostinfo) - - hostinfo.remotes.RefreshFromHandshake(vpnAddrs) - - // Don't wait for UpdateWorker - if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { - f.lightHouse.TriggerUpdate() - } - - return -} - -func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { - if hh == nil { - // Nothing here to tear down, got a bogus stage 2 packet - return true - } - - hh.Lock() - defer hh.Unlock() - - hostinfo := hh.hostinfo - if !via.IsRelayed { - // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. - if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - ) - } - return false - } - } - - ci := hostinfo.ConnectionState - msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.Error("Failed to call noise.ReadMessage", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "header", h, - ) - - // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying - // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the - // near future - return false - } else if dKey == nil || eKey == nil { - f.l.Error("Noise did not arrive at a key", - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - - // This should be impossible in IX but just in case, if we get here then there is no chance to recover - // the handshake state machine. Tear it down - return true - } - - hs := &NebulaHandshake{} - err = hs.Unmarshal(msg) - if err != nil || hs.Details == nil { - f.l.Error("Failed unmarshal handshake message", - "error", err, - "vpnAddrs", hostinfo.vpnAddrs, - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - - // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again - return true - } - - rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) - if err != nil { - f.l.Info("Handshake did not contain a certificate", - "error", err, - "from", via, - "vpnAddrs", hostinfo.vpnAddrs, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - return true - } - - remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) - if err != nil { - fp, err := rc.Fingerprint() - if err != nil { - fp = "" - } - - attrs := []slog.Attr{ - slog.Any("error", err), - slog.Any("from", via), - slog.Any("vpnAddrs", hostinfo.vpnAddrs), - slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}), - slog.String("certFingerprint", fp), - slog.Any("certVpnNetworks", rc.Networks()), - } - if f.l.Enabled(context.Background(), slog.LevelDebug) { - attrs = append(attrs, slog.Any("cert", rc)) - } - - // LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that - // callers grow conditionally, which has no pair-form equivalent. - //nolint:sloglint - f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...) - return true - } - if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) { - f.l.Info("public key mismatch between certificate and handshake", - "from", via, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "cert", remoteCert, - ) - return true - } - - if len(remoteCert.Certificate.Networks()) == 0 { - f.l.Info("No networks in certificate", - "error", err, - "from", via, - "vpnAddrs", hostinfo.vpnAddrs, - "cert", remoteCert, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - return true - } - - vpnNetworks := remoteCert.Certificate.Networks() - certName := remoteCert.Certificate.Name() - certVersion := remoteCert.Certificate.Version() - fingerprint := remoteCert.Fingerprint - issuer := remoteCert.Certificate.Issuer() - - hostinfo.remoteIndexId = hs.Details.ResponderIndex - hostinfo.lastHandshakeTime = hs.Details.Time - - // Store their cert and our symmetric keys - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - - // Make sure the current udpAddr being used is set for responding - if !via.IsRelayed { - hostinfo.SetRemote(via.UdpAddr) - } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) - } - - correctHostResponded := false - anyVpnAddrsInCommon := false - vpnAddrs := make([]netip.Addr, len(vpnNetworks)) - for i, network := range vpnNetworks { - vpnAddrs[i] = network.Addr() - if f.myVpnNetworksTable.Contains(network.Addr()) { - anyVpnAddrsInCommon = true - } - if hostinfo.vpnAddrs[0] == network.Addr() { - // todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? - correctHostResponded = true - } - } - - // Ensure the right host responded - if !correctHostResponded { - f.l.Info("Incorrect host responded to handshake", - "intendedVpnAddrs", hostinfo.vpnAddrs, - "haveVpnNetworks", vpnNetworks, - "from", via, - "certName", certName, - "certVersion", certVersion, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - ) - - // Release our old handshake from pending, it should not continue - f.handshakeManager.DeleteHostInfo(hostinfo) - - // Create a new hostinfo/handshake for the intended vpn ip - //TODO is hostinfo.vpnAddrs[0] always the address to use? - f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { - // Block the current used address - newHH.hostinfo.remotes = hostinfo.remotes - newHH.hostinfo.remotes.BlockRemote(via) - - f.l.Info("Blocked addresses for handshakes", - "blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(), - "vpnNetworks", vpnNetworks, - "remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()), - ) - - // Swap the packet store to benefit the original intended recipient - newHH.packetStore = hh.packetStore - hh.packetStore = []*cachedPacket{} - - // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnAddrs = vpnAddrs - f.sendCloseTunnel(hostinfo) - }) - - return true - } - - // Mark packet 2 as seen so it doesn't show up as missed - ci.window.Update(f.l, 2) - - duration := time.Since(hh.startTime).Nanoseconds() - msgRxL := f.l.With( - "vpnAddrs", vpnAddrs, - "from", via, - "certName", certName, - "certVersion", certVersion, - "fingerprint", fingerprint, - "issuer", issuer, - "initiatorIndex", hs.Details.InitiatorIndex, - "responderIndex", hs.Details.ResponderIndex, - "remoteIndex", h.RemoteIndex, - "handshake", m{"stage": 2, "style": "ix_psk0"}, - "durationNs", duration, - "sentCachedPackets", len(hh.packetStore), - ) - if anyVpnAddrsInCommon { - msgRxL.Info("Handshake message received") - } else { - //todo warn if not lighthouse or relay? - msgRxL.Info("Handshake message received, but no vpnNetworks in common.") - } - - // Build up the radix for the firewall if we have subnets in the cert - hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) - - // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here - f.handshakeManager.Complete(hostinfo, f) - f.connectionManager.AddTrafficWatch(hostinfo) - - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("Sending stored packets", - "count", len(hh.packetStore), - ) - } - - if len(hh.packetStore) > 0 { - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for _, cp := range hh.packetStore { - cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) - } - f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) - } - - hostinfo.remotes.RefreshFromHandshake(vpnAddrs) - f.metricHandshakes.Update(duration) - - // Don't wait for UpdateWorker - if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { - f.lightHouse.TriggerUpdate() - } - - return false -} diff --git a/handshake_manager.go b/handshake_manager.go index 8040ec2e..9fc69ff4 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -14,6 +14,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/handshake" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) @@ -23,6 +24,18 @@ const ( DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 DefaultUseRelays = true + + // maxCachedPackets is how many unsent packets we'll buffer per pending + // handshake before dropping further ones. + maxCachedPackets = 100 + + // HandshakePacket map keys mirror the IX protocol stage convention: + // stage 0 = the initiator's first message (and what the responder + // receives, stripped of header) + // stage 2 = the responder's reply + // Other handshake patterns will need new keys when added. + handshakePacketStage0 uint8 = 0 + handshakePacketStage2 uint8 = 2 ) var ( @@ -76,10 +89,11 @@ type HandshakeHostInfo struct { packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo + machine *handshake.Machine // The handshake state machine, set during stage 0 (initiator) or beginHandshake (responder multi-message) } func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { - if len(hh.packetStore) < 100 { + if len(hh.packetStore) < maxCachedPackets { tempPacket := make([]byte, len(packet)) copy(tempPacket, packet) @@ -137,6 +151,18 @@ func (hm *HandshakeManager) Run(ctx context.Context) { } func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { + // Gate on known handshake subtypes. Unknown subtypes (or future ones we + // don't yet support) are dropped here rather than silently routed through + // the IX path. Add a case when introducing a new pattern. + switch h.Subtype { + case header.HandshakeIXPSK0: + // supported + default: + hm.l.Debug("dropping handshake with unsupported subtype", + "from", via, "subtype", h.Subtype) + return + } + // First remote allow list check before we know the vpnIp if !via.IsRelayed { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { @@ -145,19 +171,27 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head } } - switch h.Subtype { - case header.HandshakeIXPSK0: - switch h.MessageCounter { - case 1: - ixHandshakeStage1(hm.f, via, packet, h) - - case 2: - newHostinfo := hm.queryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) - if tearDown && newHostinfo != nil { - hm.DeleteHostInfo(newHostinfo.hostinfo) - } + // First message of a new handshake. The wire format requires RemoteIndex + // to be zero here (the initiator has no responder index to fill in yet), + // and generateIndex never allocates 0, so any non-zero RemoteIndex on a + // stage-1 packet is malformed or someone probing for an index collision. + // Drop without paying the cost of running noise on a pending Machine. + if h.MessageCounter == 1 { + if h.RemoteIndex != 0 { + hm.l.Debug("dropping stage-1 handshake with non-zero RemoteIndex", + "from", via, "remoteIndex", h.RemoteIndex) + return } + hm.beginHandshake(via, packet, h) + return + } + + // Continuation message must match a pending handshake by index. + // Anything else is an orphaned packet (e.g., late retransmit after + // timeout) and is dropped. + if hh := hm.queryIndex(h.RemoteIndex); hh != nil { + hm.continueHandshake(via, hh, packet) + return } } @@ -183,13 +217,22 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).Info("Handshake timed out", + fields := []any{ "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), "initiatorIndex", hh.hostinfo.localIndexId, "remoteIndex", hh.hostinfo.remoteIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, "durationNs", time.Since(hh.startTime).Nanoseconds(), - ) + } + // hh.machine can be nil here if buildStage0Packet never succeeded + // (e.g., no certificate available). In that case there's no useful + // handshake metadata to log. + if hh.machine != nil { + fields = append(fields, "handshake", m{ + "stage": uint64(hh.machine.MessageIndex()), + "style": header.SubTypeName(header.Handshake, hh.machine.Subtype()), + }) + } + hh.hostinfo.logger(hm.l).Info("Handshake timed out", fields...) hm.metricTimedOut.Inc(1) hm.DeleteHostInfo(hostinfo) return @@ -200,12 +243,25 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Check if we have a handshake packet to transmit yet if !hh.ready { - if !ixHandshakeStage0(hm.f, hh) { + if !hm.buildStage0Packet(hh) { hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) return } } + // TODO: this hardcodes "always retransmit stage 0", which is correct for + // IX (the initiator only ever sends one packet, msg1) but wrong the + // moment a 3+ message pattern lands. The retry loop should resend the + // most recent outgoing message, not always stage 0. That implies + // HandshakeHostInfo tracking a single "currentOutbound" packet (bytes + + // header metadata) that gets replaced as the handshake progresses, + // instead of indexing into HandshakePacket. + stage0 := hostinfo.HandshakePacket[handshakePacketStage0] + hsFields := m{ + "stage": uint64(hh.machine.MessageIndex()), + "style": header.SubTypeName(header.Handshake, hh.machine.Subtype()), + } + // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. @@ -239,13 +295,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []netip.AddrPort hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { - hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + hm.messageMetrics.Tx(header.Handshake, hh.machine.Subtype(), 1) + err := hm.outside.WriteTo(stage0, addr) if err != nil { hostinfo.logger(hm.l).Error("Failed to send handshake message", "udpAddr", addr, "initiatorIndex", hostinfo.localIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, + "handshake", hsFields, "error", err, ) @@ -260,13 +316,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).Info("Handshake message sent", "udpAddrs", sentTo, "initiatorIndex", hostinfo.localIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, + "handshake", hsFields, ) } else if hm.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(hm.l).Debug("Handshake message sent", "udpAddrs", sentTo, "initiatorIndex", hostinfo.localIndexId, - "handshake", m{"stage": 1, "style": "ix_psk0"}, + "handshake", hsFields, ) } @@ -348,7 +404,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered switch existingRelay.State { case Established: hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String()) - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) @@ -587,7 +643,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId -func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { +func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) (uint32, error) { hm.mainHostMap.RLock() defer hm.mainHostMap.RUnlock() hm.Lock() @@ -596,7 +652,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { for range 32 { index, err := generateIndex(hm.l) if err != nil { - return err + return 0, err } _, inPending := hm.indexes[index] @@ -605,11 +661,11 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { if !inMain && !inPending { hh.hostinfo.localIndexId = index hm.indexes[index] = hh - return nil + return index, nil } } - return errors.New("failed to generate unique localIndexId") + return 0, errors.New("failed to generate unique localIndexId") } func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { @@ -728,3 +784,524 @@ func generateIndex(l *slog.Logger) (uint32, error) { func hsTimeout(tries int64, interval time.Duration) time.Duration { return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) } + +// buildStage0Packet creates the initial handshake packet for the initiator. +func (hm *HandshakeManager) buildStage0Packet(hh *HandshakeHostInfo) bool { + cs := hm.f.pki.getCertState() + v := cs.DefaultVersion() + if hh.initiatingVersionOverride != cert.VersionPre1 { + v = hh.initiatingVersionOverride + } else if v < cert.Version2 { + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } + } + + cred := cs.GetCredential(v) + if cred == nil { + hm.f.l.Error("Unable to handshake with host because no certificate is available", + "vpnAddrs", hh.hostinfo.vpnAddrs, "certVersion", v) + return false + } + + machine, err := handshake.NewMachine( + v, cs.GetCredential, + hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) }, + true, header.HandshakeIXPSK0, + ) + if err != nil { + hm.f.l.Error("Failed to create handshake machine", + "vpnAddrs", hh.hostinfo.vpnAddrs, "error", err) + return false + } + + msg, err := machine.Initiate(nil) + if err != nil { + hm.f.l.Error("Failed to initiate handshake", + "vpnAddrs", hh.hostinfo.vpnAddrs, "error", err) + return false + } + + // hostinfo.ConnectionState stays nil until the handshake completes in + // continueHandshake. Pre-completion control surfaces guard with nil + // checks; the data plane never observes a pending hostinfo. + hh.hostinfo.HandshakePacket[handshakePacketStage0] = msg + hh.machine = machine + hh.ready = true + return true +} + +// beginHandshake handles an incoming handshake packet that doesn't match any +// existing pending handshake. It creates a new responder Machine and processes +// the first message. +func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *header.H) { + f := hm.f + cs := f.pki.getCertState() + + v := cs.DefaultVersion() + if cs.GetCredential(v) == nil { + f.l.Error("Unable to handshake with host because no certificate is available", + "from", via, "certVersion", v) + return + } + + machine, err := handshake.NewMachine( + v, cs.GetCredential, + hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) }, + false, header.HandshakeIXPSK0, + ) + if err != nil { + f.l.Error("Failed to create handshake machine", "from", via, "error", err) + return + } + + response, result, err := machine.ProcessPacket(nil, packet) + if err != nil { + f.l.Error("Failed to process handshake packet", "from", via, "error", err) + return + } + + if result == nil { + // Multi-message pattern: the responder Machine would need to be + // registered in hm.indexes so a future inbound packet finds it via + // continueHandshake. The current manager doesn't do that yet, so + // fail loudly rather than silently dropping the in-flight handshake. + // TODO: support multi-message responder flows (XX, pqIX, etc.). + // See also the IX-shaped cipher key assignment in handshake.Machine. + f.l.Error("multi-message handshake responder is not supported", + "from", via, "error", handshake.ErrMultiMessageUnsupported) + return + } + + remoteCert := result.RemoteCert + if remoteCert == nil { + f.l.Error("Handshake did not produce a peer certificate", "from", via) + return + } + + // Validate peer identity + vpnAddrs, anyVpnAddrsInCommon, ok := hm.validatePeerCert(via, remoteCert) + if !ok { + return + } + + hostinfo := &HostInfo{ + ConnectionState: newConnectionStateFromResult(result), + localIndexId: result.LocalIndex, + remoteIndexId: result.RemoteIndex, + vpnAddrs: vpnAddrs, + HandshakePacket: make(map[uint8][]byte, 0), + lastHandshakeTime: result.HandshakeTime, + relayState: RelayState{ + relays: nil, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + } + + msg := "Handshake message received" + if !anyVpnAddrsInCommon { + msg = "Handshake message received, but no vpnNetworks in common." + } + f.l.Info(msg, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "initiatorIndex", result.RemoteIndex, + "responderIndex", result.LocalIndex, + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + + // packet aliases the listener's incoming buffer, so this copy must stay. + hostinfo.HandshakePacket[handshakePacketStage0] = make([]byte, len(packet[header.Len:])) + copy(hostinfo.HandshakePacket[handshakePacketStage0], packet[header.Len:]) + + // response was freshly allocated by ProcessPacket; safe to retain directly. + if response != nil { + hostinfo.HandshakePacket[handshakePacketStage2] = response + } + + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) + + existing, err := hm.CheckAndComplete(hostinfo, handshakePacketStage0, f) + if err != nil { + hm.handleCheckAndCompleteError(err, existing, hostinfo, via) + return + } + + hm.sendHandshakeResponse(via, response, hostinfo, false) + f.connectionManager.AddTrafficWatch(hostinfo) + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } +} + +// continueHandshake feeds an incoming packet to an existing pending handshake Machine. +func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostInfo, packet []byte) { + f := hm.f + + hh.Lock() + defer hh.Unlock() + + // Re-verify hh is still tracked. Between queryIndex returning and us taking + // hh.Lock, handleOutbound may have timed out and deleted it. Once we hold + // hh.Lock no other deleter can race our index: handleOutbound also takes + // hh.Lock first, and handleRecvError targets a main-hostmap entry with a + // different localIndexId. + hm.RLock() + cur, ok := hm.indexes[hh.hostinfo.localIndexId] + hm.RUnlock() + if !ok || cur != hh { + return + } + + hostinfo := hh.hostinfo + if !via.IsRelayed { + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + return + } + } + + machine := hh.machine + if machine == nil { + f.l.Error("No handshake machine available for continuation", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + hm.DeleteHostInfo(hostinfo) + return + } + + response, result, err := machine.ProcessPacket(nil, packet) + if err != nil { + // Recoverable errors are routine noise, log at Debug. Fatal errors get a Warn. + if machine.Failed() { + f.l.Warn("Failed to process handshake packet, abandoning", + "vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err) + hm.DeleteHostInfo(hostinfo) + } else { + f.l.Debug("Failed to process handshake packet", + "vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err) + } + return + } + + if response != nil { + hm.sendHandshakeResponse(via, response, hostinfo, false) + } + + if result == nil { + return + } + + // Handshake complete; build the ConnectionState now that we have keys and a verified peer cert. + hostinfo.ConnectionState = newConnectionStateFromResult(result) + + remoteCert := result.RemoteCert + if remoteCert == nil { + f.l.Error("Handshake completed without peer certificate", + "vpnAddrs", hostinfo.vpnAddrs, "from", via) + hm.DeleteHostInfo(hostinfo) + return + } + + vpnNetworks := remoteCert.Certificate.Networks() + hostinfo.remoteIndexId = result.RemoteIndex + hostinfo.lastHandshakeTime = result.HandshakeTime + + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } else { + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + } + + // Verify correct host responded (initiator check) + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + correctHostResponded := false + anyVpnAddrsInCommon := false + for i, network := range vpnNetworks { + // inside.go drops self-routed packets at the firewall stage, but we'd + // rather not let a self-handshake complete in the first place: it + // wastes a hostmap slot, suppresses no log, and obscures routing + // misconfig. Explicit refusal here mirrors the responder-side check + // in validatePeerCert. + if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + hm.DeleteHostInfo(hostinfo) + return + } + vpnAddrs[i] = network.Addr() + if hostinfo.vpnAddrs[0] == network.Addr() { + correctHostResponded = true + } + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + } + + if !correctHostResponded { + f.l.Info("Incorrect host responded to handshake", + "intendedVpnAddrs", hostinfo.vpnAddrs, + "haveVpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + ) + + hm.DeleteHostInfo(hostinfo) + hm.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(via) + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} + hostinfo.vpnAddrs = vpnAddrs + f.sendCloseTunnel(hostinfo) + }) + return + } + + duration := time.Since(hh.startTime).Nanoseconds() + msg := "Handshake message received" + if !anyVpnAddrsInCommon { + msg = "Handshake message received, but no vpnNetworks in common." + } + f.l.Info(msg, + "vpnAddrs", vpnAddrs, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + "initiatorIndex", result.LocalIndex, + "responderIndex", result.RemoteIndex, + "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + "durationNs", duration, + "sentCachedPackets", len(hh.packetStore), + ) + + hostinfo.vpnAddrs = vpnAddrs + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) + + hm.Complete(hostinfo, f) + f.connectionManager.AddTrafficWatch(hostinfo) + + if len(hh.packetStore) > 0 { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore)) + } + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for _, cp := range hh.packetStore { + cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) + } + f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) + } + + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) + f.metricHandshakes.Update(duration) + + // Don't wait for UpdateWorker + if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) { + f.lightHouse.TriggerUpdate() + } +} + +// validatePeerCert checks the peer certificate for self-connection and remote allow list. +// Returns the VPN addrs, whether any of them fall within one of our own VPN +// networks, and true if valid; false if rejected. +func (hm *HandshakeManager) validatePeerCert(via ViaSender, remoteCert *cert.CachedCertificate) ([]netip.Addr, bool, bool) { + f := hm.f + vpnNetworks := remoteCert.Certificate.Networks() + + // The cert package rejects host certs with no networks at parse time, so + // reaching this state would mean an invariant was bypassed elsewhere. + // Refuse explicitly so downstream code (which indexes vpnAddrs[0]) can't + // panic if that invariant ever changes. + if len(vpnNetworks) == 0 { + f.l.Info("No networks in certificate", + "from", via, "cert", remoteCert) + return nil, false, false + } + + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + anyVpnAddrsInCommon := false + + for i, network := range vpnNetworks { + if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.Error("Refusing to handshake with myself", + "vpnNetworks", vpnNetworks, + "from", via, + "certName", remoteCert.Certificate.Name(), + "certVersion", remoteCert.Certificate.Version(), + "fingerprint", remoteCert.Fingerprint, + "issuer", remoteCert.Certificate.Issuer(), + ) + return nil, false, false + } + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + } + + if !via.IsRelayed { + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { + f.l.Debug("lighthouse.remote_allow_list denied incoming handshake", + "vpnAddrs", vpnAddrs, "from", via) + return nil, false, false + } + } + + return vpnAddrs, anyVpnAddrsInCommon, true +} + +// sendHandshakeResponse sends a handshake response via the appropriate transport. +// cached is true when msg is a stored response being retransmitted because +// the peer's stage-1 retransmit landed (the ErrAlreadySeen path); false on a +// fresh response. +func (hm *HandshakeManager) sendHandshakeResponse(via ViaSender, msg []byte, hostinfo *HostInfo, cached bool) { + if msg == nil { + return + } + + f := hm.f + f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) + + // Common log fields. peerCert may be nil during intermediate + // multi-message flows (handshake hasn't completed yet); skip the cert + // block if so. + logFields := []any{ + "vpnAddrs", hostinfo.vpnAddrs, + "handshake", m{"stage": uint64(2), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)}, + "cached", cached, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + } + if peerCert := hostinfo.ConnectionState.peerCert; peerCert != nil { + logFields = append(logFields, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + ) + } + + if !via.IsRelayed { + fields := append(logFields, "from", via) + err := f.outside.WriteTo(msg, via.UdpAddr) + if err != nil { + f.l.Error("Failed to send handshake message", append(fields, "error", err)...) + } else { + f.l.Info("Handshake message sent", fields...) + } + } else { + if via.relay == nil { + f.l.Error("Handshake send failed: both addr and via.relay are nil.") + return + } + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + // We received a valid handshake on this relay, so make sure the relay + // state reflects that, in case it had been marked Disestablished. + via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...) + } +} + +// handleCheckAndCompleteError handles errors from CheckAndComplete. +// This only fires from the responder-side beginHandshake path, after the +// peer cert has been validated and ConnectionState populated, so peerCert +// is always non-nil for the cases that log it. +func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hostinfo *HostInfo, via ViaSender) { + f := hm.f + peerCert := hostinfo.ConnectionState.peerCert + hsFields := m{"stage": uint64(1), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)} + + switch err { + case ErrAlreadySeen: + if existing.SetRemoteIfPreferred(f.hostMap, via) { + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + } + // Resend the original response. The peer is committed to that response's + // ephemeral keys; a freshly-built one would have different keys and break + // the tunnel even though both sides "completed" the handshake. + if msg := existing.HandshakePacket[handshakePacketStage2]; msg != nil { + hm.sendHandshakeResponse(via, msg, existing, true) + } + + case ErrExistingHostInfo: + f.l.Info("Handshake too old", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "oldHandshakeTime", existing.lastHandshakeTime, + "newHandshakeTime", hostinfo.lastHandshakeTime, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + + case ErrLocalIndexCollision: + f.l.Error("Failed to add HostInfo due to localIndex collision", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "localIndex", hostinfo.localIndexId, + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + + default: + f.l.Error("Failed to add HostInfo to HostMap", + "vpnAddrs", hostinfo.vpnAddrs, + "from", via, + "error", err, + "certName", peerCert.Certificate.Name(), + "certVersion", peerCert.Certificate.Version(), + "fingerprint", peerCert.Fingerprint, + "issuer", peerCert.Certificate.Issuer(), + "initiatorIndex", hostinfo.remoteIndexId, + "responderIndex", hostinfo.localIndexId, + "handshake", hsFields, + ) + } +} + +// certVerifier returns a CertVerifier that validates certs against the current CA pool. +func (hm *HandshakeManager) certVerifier() handshake.CertVerifier { + return func(c cert.Certificate) (*cert.CachedCertificate, error) { + return hm.f.pki.GetCAPool().VerifyCertificate(time.Now(), c) + } +} diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 2e6d34b5..5f8383e4 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -27,7 +28,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, - v1HandshakeBytes: []byte{}, + v1Credential: nil, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -100,3 +101,137 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { func (mw *mockEncWriter) GetCertState() *CertState { return &CertState{initiatingVersion: cert.Version2} } + +func TestValidatePeerCert(t *testing.T) { + l := test.NewLogger() + + myNetwork := netip.MustParsePrefix("10.0.0.1/24") + myAddrTable := new(bart.Lite) + myAddrTable.Insert(netip.PrefixFrom(myNetwork.Addr(), myNetwork.Addr().BitLen())) + myNetTable := new(bart.Lite) + myNetTable.Insert(myNetwork.Masked()) + + newHM := func() *HandshakeManager { + hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig) + hm.f = &Interface{ + handshakeManager: hm, + pki: &PKI{}, + l: l, + myVpnAddrsTable: myAddrTable, + myVpnNetworksTable: myNetTable, + lightHouse: hm.lightHouse, + } + return hm + } + + cached := func(networks ...netip.Prefix) *cert.CachedCertificate { + return &cert.CachedCertificate{ + Certificate: &dummyCert{name: "peer", networks: networks}, + } + } + + via := ViaSender{ + UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"), + IsRelayed: true, // skip the remote allow list (covered separately) + } + + t.Run("addr inside our networks sets anyVpnAddrsInCommon", func(t *testing.T) { + hm := newHM() + // 10.0.0.2 falls inside our 10.0.0.0/24 + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.2/24"))) + assert.True(t, ok) + assert.True(t, common) + assert.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.2")}, addrs) + }) + + t.Run("addr outside our networks leaves anyVpnAddrsInCommon false", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("192.168.1.5/24"))) + assert.True(t, ok) + assert.False(t, common) + assert.Equal(t, []netip.Addr{netip.MustParseAddr("192.168.1.5")}, addrs) + }) + + t.Run("any matching network is enough", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached( + netip.MustParsePrefix("192.168.1.5/24"), + netip.MustParsePrefix("10.0.0.42/24"), + )) + assert.True(t, ok) + assert.True(t, common) + assert.Len(t, addrs, 2) + }) + + t.Run("self-handshake is rejected", func(t *testing.T) { + hm := newHM() + // 10.0.0.1 is in myVpnAddrsTable + addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.1/24"))) + assert.False(t, ok) + assert.False(t, common) + assert.Nil(t, addrs) + }) + + t.Run("cert with no networks is rejected", func(t *testing.T) { + hm := newHM() + addrs, common, ok := hm.validatePeerCert(via, cached()) + assert.False(t, ok) + assert.False(t, common) + assert.Nil(t, addrs) + }) +} + +func TestHandleIncomingDispatch(t *testing.T) { + l := test.NewLogger() + + newHM := func() *HandshakeManager { + hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig) + hm.f = &Interface{ + handshakeManager: hm, + pki: &PKI{}, + l: l, + } + return hm + } + + via := ViaSender{ + UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"), + IsRelayed: true, // bypass remote allow list + } + + // A packet body of zero length is fine for these tests: dispatch is + // gated on header fields, and we assert that we never reach noise/cert + // processing for any of the malformed shapes here. + pkt := make([]byte, header.Len) + + t.Run("unsupported subtype dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{Type: header.Handshake, Subtype: header.MessageSubType(99), MessageCounter: 1} + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "no pending handshake should be created") + }) + + t.Run("stage-1 with non-zero RemoteIndex dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{ + Type: header.Handshake, + Subtype: header.HandshakeIXPSK0, + RemoteIndex: 0xdeadbeef, + MessageCounter: 1, + } + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "spoofed stage-1 must not create a pending machine") + }) + + t.Run("continuation with no matching pending index dropped", func(t *testing.T) { + hm := newHM() + h := &header.H{ + Type: header.Handshake, + Subtype: header.HandshakeIXPSK0, + RemoteIndex: 0xcafef00d, + MessageCounter: 2, + } + hm.HandleIncoming(via, pkt, h) + assert.Empty(t, hm.indexes, "orphan stage-2 must not create state") + }) +} diff --git a/nebula.pb.go b/nebula.pb.go index 2fd2ff66..94a4ebe2 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8, 0} + return fileDescriptor_2d65afa7693df5ef, []int{6, 0} } type NebulaMeta struct { @@ -489,142 +489,6 @@ func (m *NebulaPing) GetTime() uint64 { return 0 } -type NebulaHandshake struct { - Details *NebulaHandshakeDetails `protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"` - Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,proto3" json:"Hmac,omitempty"` -} - -func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } -func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } -func (*NebulaHandshake) ProtoMessage() {} -func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} -} -func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *NebulaHandshake) XXX_Merge(src proto.Message) { - xxx_messageInfo_NebulaHandshake.Merge(m, src) -} -func (m *NebulaHandshake) XXX_Size() int { - return m.Size() -} -func (m *NebulaHandshake) XXX_DiscardUnknown() { - xxx_messageInfo_NebulaHandshake.DiscardUnknown(m) -} - -var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo - -func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails { - if m != nil { - return m.Details - } - return nil -} - -func (m *NebulaHandshake) GetHmac() []byte { - if m != nil { - return m.Hmac - } - return nil -} - -type NebulaHandshakeDetails struct { - Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"` - InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"` - ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` - Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` - Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` - CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` -} - -func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } -func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } -func (*NebulaHandshakeDetails) ProtoMessage() {} -func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} -} -func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) { - xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src) -} -func (m *NebulaHandshakeDetails) XXX_Size() int { - return m.Size() -} -func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() { - xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m) -} - -var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo - -func (m *NebulaHandshakeDetails) GetCert() []byte { - if m != nil { - return m.Cert - } - return nil -} - -func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 { - if m != nil { - return m.InitiatorIndex - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 { - if m != nil { - return m.ResponderIndex - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetCookie() uint64 { - if m != nil { - return m.Cookie - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetTime() uint64 { - if m != nil { - return m.Time - } - return 0 -} - -func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { - if m != nil { - return m.CertVersion - } - return 0 -} - type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` @@ -639,7 +503,7 @@ func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -729,65 +593,55 @@ func init() { proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") - proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") - proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl") } func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 785 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, - 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, - 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, - 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, - 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, - 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed, - 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, - 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, - 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, - 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, - 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, - 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, - 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, - 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, - 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, - 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, - 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, - 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, - 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, - 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, - 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, - 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, - 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, - 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, - 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, - 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, - 0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, - 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, - 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, - 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, - 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, - 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, - 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, - 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, - 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, - 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, - 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, - 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, - 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, - 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, - 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, - 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, - 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, - 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, - 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, - 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, - 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, - 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, - 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, - 0x00, + // 665 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x54, 0xcd, 0x6e, 0xd3, 0x5c, + 0x10, 0x8d, 0x1d, 0x27, 0x69, 0x27, 0x4d, 0x3e, 0x7f, 0x53, 0x51, 0x12, 0x24, 0xac, 0xe0, 0x45, + 0x55, 0xb1, 0x48, 0x51, 0x5a, 0xba, 0xa6, 0x2d, 0x42, 0xa9, 0xd4, 0x9f, 0x70, 0x55, 0x8a, 0xc4, + 0xce, 0xb5, 0x2f, 0x8d, 0x55, 0xc7, 0x37, 0xb5, 0x6f, 0x50, 0xf3, 0x16, 0x3c, 0x0c, 0x0f, 0x01, + 0xbb, 0x2e, 0x59, 0xa2, 0x66, 0xc9, 0x92, 0x17, 0x40, 0xf7, 0xfa, 0xbf, 0x31, 0xb0, 0xbb, 0x33, + 0xe7, 0x9c, 0x99, 0xc9, 0xc9, 0x8c, 0x61, 0xcd, 0xa7, 0x97, 0x33, 0xcf, 0xea, 0x4f, 0x03, 0xc6, + 0x19, 0xd6, 0xa3, 0xc8, 0xfc, 0xa9, 0x02, 0x9c, 0xca, 0xe7, 0x09, 0xe5, 0x16, 0x0e, 0x40, 0x3b, + 0x9f, 0x4f, 0x69, 0x47, 0xe9, 0x29, 0x5b, 0xed, 0x81, 0xd1, 0x8f, 0x35, 0x19, 0xa3, 0x7f, 0x42, + 0xc3, 0xd0, 0xba, 0xa2, 0x82, 0x45, 0x24, 0x17, 0x77, 0xa0, 0xf1, 0x9a, 0x72, 0xcb, 0xf5, 0xc2, + 0x8e, 0xda, 0x53, 0xb6, 0x9a, 0x83, 0xee, 0xb2, 0x2c, 0x26, 0x90, 0x84, 0x69, 0xfe, 0x52, 0xa0, + 0x99, 0x2b, 0x85, 0x2b, 0xa0, 0x9d, 0x32, 0x9f, 0xea, 0x15, 0x6c, 0xc1, 0xea, 0x90, 0x85, 0xfc, + 0xed, 0x8c, 0x06, 0x73, 0x5d, 0x41, 0x84, 0x76, 0x1a, 0x12, 0x3a, 0xf5, 0xe6, 0xba, 0x8a, 0x4f, + 0x60, 0x43, 0xe4, 0xde, 0x4d, 0x1d, 0x8b, 0xd3, 0x53, 0xc6, 0xdd, 0x8f, 0xae, 0x6d, 0x71, 0x97, + 0xf9, 0x7a, 0x15, 0xbb, 0xf0, 0x48, 0x60, 0x27, 0xec, 0x13, 0x75, 0x0a, 0x90, 0x96, 0x40, 0xa3, + 0x99, 0x6f, 0x8f, 0x0b, 0x50, 0x0d, 0xdb, 0x00, 0x02, 0x7a, 0x3f, 0x66, 0xd6, 0xc4, 0xd5, 0xeb, + 0xb8, 0x0e, 0xff, 0x65, 0x71, 0xd4, 0xb6, 0x21, 0x26, 0x1b, 0x59, 0x7c, 0x7c, 0x38, 0xa6, 0xf6, + 0xb5, 0xbe, 0x22, 0x26, 0x4b, 0xc3, 0x88, 0xb2, 0x8a, 0x4f, 0xa1, 0x5b, 0x3e, 0xd9, 0xbe, 0x7d, + 0xad, 0x83, 0xf9, 0x4d, 0x85, 0xff, 0x97, 0x4c, 0x41, 0x13, 0xe0, 0xcc, 0x73, 0x2e, 0xa6, 0xfe, + 0xbe, 0xe3, 0x04, 0xd2, 0xfa, 0xd6, 0x81, 0xda, 0x51, 0x48, 0x2e, 0x8b, 0x9b, 0xd0, 0x48, 0x08, + 0x75, 0x69, 0xf2, 0x5a, 0x62, 0xb2, 0xc8, 0x91, 0x04, 0xc4, 0x3e, 0xe8, 0x67, 0x9e, 0x43, 0xa8, + 0x67, 0xcd, 0xe3, 0x54, 0xd8, 0xa9, 0xf5, 0xaa, 0x71, 0xc5, 0x25, 0x0c, 0x07, 0xd0, 0x2a, 0x92, + 0x1b, 0xbd, 0xea, 0x52, 0xf5, 0x22, 0x05, 0x77, 0xa1, 0x79, 0xb1, 0x2b, 0x9e, 0x23, 0x16, 0x70, + 0xf1, 0xa7, 0x0b, 0x05, 0x26, 0x8a, 0x0c, 0x22, 0x79, 0x9a, 0x54, 0xed, 0x65, 0x2a, 0xed, 0x81, + 0x6a, 0x2f, 0xa7, 0xca, 0x68, 0xd8, 0x81, 0x86, 0xcd, 0x66, 0x3e, 0xa7, 0x41, 0xa7, 0x2a, 0x8c, + 0x21, 0x49, 0x68, 0x6e, 0x82, 0x26, 0x7f, 0x71, 0x1b, 0xd4, 0xa1, 0x2b, 0x5d, 0xd3, 0x88, 0x3a, + 0x74, 0x45, 0x7c, 0xcc, 0xe4, 0x26, 0x6a, 0x44, 0x3d, 0x66, 0xe6, 0x2e, 0x40, 0x36, 0x06, 0x62, + 0xa4, 0x8a, 0x5c, 0x26, 0x51, 0x05, 0x04, 0x4d, 0x60, 0x52, 0xd3, 0x22, 0xf2, 0x6d, 0xbe, 0x02, + 0xc8, 0xc6, 0xf8, 0x57, 0x8f, 0xb4, 0x42, 0x35, 0x57, 0xe1, 0x36, 0x39, 0xac, 0x91, 0xeb, 0x5f, + 0xfd, 0xfd, 0xb0, 0x04, 0xa3, 0xe4, 0xb0, 0x10, 0xb4, 0x73, 0x77, 0x42, 0xe3, 0x3e, 0xf2, 0x6d, + 0x9a, 0x4b, 0x67, 0x23, 0xc4, 0x7a, 0x05, 0x57, 0xa1, 0x16, 0x2d, 0xa1, 0x62, 0x7e, 0xa9, 0x42, + 0x2b, 0x2a, 0x7c, 0xc8, 0x7c, 0x1e, 0x30, 0x0f, 0x5f, 0x16, 0xba, 0x3f, 0x2b, 0x76, 0x8f, 0x49, + 0x25, 0x03, 0xbc, 0x80, 0xf5, 0x23, 0xdf, 0xe5, 0xae, 0xc5, 0x59, 0x20, 0x57, 0xe0, 0xc8, 0x77, + 0xe8, 0x6d, 0xec, 0x53, 0x19, 0x24, 0x14, 0x84, 0x86, 0x53, 0xe6, 0x3b, 0x34, 0xaf, 0x88, 0x7c, + 0x29, 0x83, 0xf0, 0x39, 0xb4, 0x93, 0xa5, 0x3c, 0x67, 0xf2, 0xaf, 0xd1, 0xd2, 0x03, 0x78, 0x80, + 0xe4, 0x97, 0xfb, 0x4d, 0xc0, 0x26, 0x92, 0x5d, 0x4b, 0xd9, 0x4b, 0x18, 0xf6, 0xa1, 0x99, 0x2f, + 0x5c, 0x76, 0x38, 0x79, 0x42, 0x7a, 0x0c, 0x69, 0xf1, 0x46, 0x89, 0xa2, 0x48, 0x31, 0x87, 0x7f, + 0xfa, 0x8e, 0x6d, 0x00, 0x1e, 0x06, 0xd4, 0xe2, 0x54, 0xf2, 0x09, 0xbd, 0x99, 0xd1, 0x90, 0xeb, + 0x0a, 0x3e, 0x86, 0xf5, 0x42, 0x5e, 0x58, 0x12, 0x52, 0x5d, 0x3d, 0xd8, 0xf9, 0x7a, 0x6f, 0x28, + 0x77, 0xf7, 0x86, 0xf2, 0xe3, 0xde, 0x50, 0x3e, 0x2f, 0x8c, 0xca, 0xdd, 0xc2, 0xa8, 0x7c, 0x5f, + 0x18, 0x95, 0x0f, 0xdd, 0x2b, 0x97, 0x8f, 0x67, 0x97, 0x7d, 0x9b, 0x4d, 0xb6, 0x43, 0xcf, 0xb2, + 0xaf, 0xc7, 0x37, 0xdb, 0xd1, 0x48, 0x97, 0x75, 0xf9, 0x39, 0xdf, 0xf9, 0x1d, 0x00, 0x00, 0xff, + 0xff, 0x51, 0x0a, 0xe3, 0xd7, 0xde, 0x05, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -1072,103 +926,6 @@ func (m *NebulaPing) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } -func (m *NebulaHandshake) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NebulaHandshake) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NebulaHandshake) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if len(m.Hmac) > 0 { - i -= len(m.Hmac) - copy(dAtA[i:], m.Hmac) - i = encodeVarintNebula(dAtA, i, uint64(len(m.Hmac))) - i-- - dAtA[i] = 0x12 - } - if m.Details != nil { - { - size, err := m.Details.MarshalToSizedBuffer(dAtA[:i]) - if err != nil { - return 0, err - } - i -= size - i = encodeVarintNebula(dAtA, i, uint64(size)) - } - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - -func (m *NebulaHandshakeDetails) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NebulaHandshakeDetails) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if m.CertVersion != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) - i-- - dAtA[i] = 0x40 - } - if m.Time != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Time)) - i-- - dAtA[i] = 0x28 - } - if m.Cookie != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Cookie)) - i-- - dAtA[i] = 0x20 - } - if m.ResponderIndex != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.ResponderIndex)) - i-- - dAtA[i] = 0x18 - } - if m.InitiatorIndex != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.InitiatorIndex)) - i-- - dAtA[i] = 0x10 - } - if len(m.Cert) > 0 { - i -= len(m.Cert) - copy(dAtA[i:], m.Cert) - i = encodeVarintNebula(dAtA, i, uint64(len(m.Cert))) - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - func (m *NebulaControl) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -1375,51 +1132,6 @@ func (m *NebulaPing) Size() (n int) { return n } -func (m *NebulaHandshake) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.Details != nil { - l = m.Details.Size() - n += 1 + l + sovNebula(uint64(l)) - } - l = len(m.Hmac) - if l > 0 { - n += 1 + l + sovNebula(uint64(l)) - } - return n -} - -func (m *NebulaHandshakeDetails) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - l = len(m.Cert) - if l > 0 { - n += 1 + l + sovNebula(uint64(l)) - } - if m.InitiatorIndex != 0 { - n += 1 + sovNebula(uint64(m.InitiatorIndex)) - } - if m.ResponderIndex != 0 { - n += 1 + sovNebula(uint64(m.ResponderIndex)) - } - if m.Cookie != 0 { - n += 1 + sovNebula(uint64(m.Cookie)) - } - if m.Time != 0 { - n += 1 + sovNebula(uint64(m.Time)) - } - if m.CertVersion != 0 { - n += 1 + sovNebula(uint64(m.CertVersion)) - } - return n -} - func (m *NebulaControl) Size() (n int) { if m == nil { return 0 @@ -2236,305 +1948,6 @@ func (m *NebulaPing) Unmarshal(dAtA []byte) error { } return nil } -func (m *NebulaHandshake) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NebulaHandshake: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NebulaHandshake: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Details", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + msglen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.Details == nil { - m.Details = &NebulaHandshakeDetails{} - } - if err := m.Details.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Hmac", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Hmac = append(m.Hmac[:0], dAtA[iNdEx:postIndex]...) - if m.Hmac == nil { - m.Hmac = []byte{} - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipNebula(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthNebula - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NebulaHandshakeDetails: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NebulaHandshakeDetails: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Cert", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthNebula - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthNebula - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Cert = append(m.Cert[:0], dAtA[iNdEx:postIndex]...) - if m.Cert == nil { - m.Cert = []byte{} - } - iNdEx = postIndex - case 2: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field InitiatorIndex", wireType) - } - m.InitiatorIndex = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.InitiatorIndex |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 3: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field ResponderIndex", wireType) - } - m.ResponderIndex = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.ResponderIndex |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 4: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Cookie", wireType) - } - m.Cookie = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Cookie |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 5: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) - } - m.Time = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.Time |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - case 8: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) - } - m.CertVersion = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowNebula - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.CertVersion |= uint32(b&0x7F) << shift - if b < 0x80 { - break - } - } - default: - iNdEx = preIndex - skippy, err := skipNebula(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthNebula - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} func (m *NebulaControl) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 diff --git a/nebula.proto b/nebula.proto index ea102334..7b44f473 100644 --- a/nebula.proto +++ b/nebula.proto @@ -60,21 +60,9 @@ message NebulaPing { uint64 Time = 2; } -message NebulaHandshake { - NebulaHandshakeDetails Details = 1; - bytes Hmac = 2; -} - -message NebulaHandshakeDetails { - bytes Cert = 1; - uint32 InitiatorIndex = 2; - uint32 ResponderIndex = 3; - uint64 Cookie = 4; - uint64 Time = 5; - uint32 CertVersion = 8; - // reserved for WIP multiport - reserved 6, 7; -} +// NebulaHandshake / NebulaHandshakeDetails moved to +// handshake/handshake.proto. The handshake package speaks that wire format +// directly via a hand-written encoder/decoder. message NebulaControl { enum MessageType { diff --git a/pki.go b/pki.go index fb8cc5c6..acc80486 100644 --- a/pki.go +++ b/pki.go @@ -15,9 +15,12 @@ import ( "sync/atomic" "time" + "github.com/flynn/noise" "github.com/gaissmai/bart" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/util" ) @@ -28,11 +31,11 @@ type PKI struct { } type CertState struct { - v1Cert cert.Certificate - v1HandshakeBytes []byte + v1Cert cert.Certificate + v1Credential *handshake.Credential - v2Cert cert.Certificate - v2HandshakeBytes []byte + v2Cert cert.Certificate + v2Credential *handshake.Credential initiatingVersion cert.Version privateKey []byte @@ -92,13 +95,35 @@ func (p *PKI) reload(c *config.C, initial bool) error { } func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { - newState, err := newCertStateFromConfig(c) + var cipher string + var currentState *CertState + if initial { + cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: + return util.NewContextualError( + "unknown cipher", + m{"cipher": cipher}, + nil, + ) + } + } else { + // Cipher cant be hot swapped so just leave it at what it was before + currentState = p.cs.Load() + cipher = currentState.cipher + } + + newState, err := newCertStateFromConfig(c, cipher) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } - if !initial { - currentState := p.cs.Load() + if currentState != nil { if newState.v1Cert != nil { if currentState.v1Cert == nil { //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). @@ -158,25 +183,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { ) } } - - // Cipher cant be hot swapped so just leave it at what it was before - newState.cipher = currentState.cipher - - } else { - newState.cipher = c.GetString("cipher", "aes") - //TODO: this sucks and we should make it not a global - switch newState.cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return util.NewContextualError( - "unknown cipher", - m{"cipher": newState.cipher}, - nil, - ) - } } p.cs.Store(newState) @@ -208,6 +214,20 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate { return c } +// DefaultVersion returns the preferred cert version for initiating handshakes. +func (cs *CertState) DefaultVersion() cert.Version { return cs.initiatingVersion } + +// GetCredential returns the pre-computed handshake credential for the given version, or nil. +func (cs *CertState) GetCredential(v cert.Version) *handshake.Credential { + switch v { + case cert.Version1: + return cs.v1Credential + case cert.Version2: + return cs.v2Credential + } + return nil +} + func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { switch v { case cert.Version1: @@ -219,17 +239,25 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { return nil } -// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. -// Callers must check if the return []byte is nil. -func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { - switch v { - case cert.Version1: - return cs.v1HandshakeBytes - case cert.Version2: - return cs.v2HandshakeBytes +func newCipherSuite(curve cert.Curve, pkcs11backed bool, cipher string) (noise.CipherSuite, error) { + var dhFunc noise.DHFunc + switch curve { + case cert.Curve_CURVE25519: + dhFunc = noise.DH25519 + case cert.Curve_P256: + if pkcs11backed { + dhFunc = noiseutil.DHP256PKCS11 + } else { + dhFunc = noiseutil.DHP256 + } default: - return nil + return nil, fmt.Errorf("unsupported curve: %s", curve) } + + if cipher == "chachapoly" { + return noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256), nil + } + return noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256), nil } func (cs *CertState) String() string { @@ -261,7 +289,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) { return json.Marshal(msg) } -func newCertStateFromConfig(c *config.C) (*CertState, error) { +func newCertStateFromConfig(c *config.C, cipher string) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") @@ -345,13 +373,14 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } - return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) + return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey, cipher) } -func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte, cipher string) (*CertState, error) { cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, + cipher: cipher, myVpnNetworksTable: new(bart.Lite), myVpnAddrsTable: new(bart.Lite), myVpnBroadcastAddrsTable: new(bart.Lite), @@ -384,10 +413,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p v1hs, err := v1.MarshalForHandshakes() if err != nil { - return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + return nil, fmt.Errorf("error marshalling v1 certificate for handshake: %w", err) + } + ncs, err := newCipherSuite(v1.Curve(), pkcs11backed, cipher) + if err != nil { + return nil, err } cs.v1Cert = v1 - cs.v1HandshakeBytes = v1hs + cs.v1Credential = handshake.NewCredential(v1, v1hs, privateKey, ncs) if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version1 @@ -405,10 +438,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p v2hs, err := v2.MarshalForHandshakes() if err != nil { - return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + return nil, fmt.Errorf("error marshalling v2 certificate for handshake: %w", err) + } + ncs, err := newCipherSuite(v2.Curve(), pkcs11backed, cipher) + if err != nil { + return nil, err } cs.v2Cert = v2 - cs.v2HandshakeBytes = v2hs + cs.v2Credential = handshake.NewCredential(v2, v2hs, privateKey, ncs) if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version2