From 5f920fdd7d5af2510516ef3e6dbd9543de8019ae Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 6 May 2026 17:37:03 -0500 Subject: [PATCH] Remove the global noiseEndianness var (#1707) --- connection_state.go | 9 +- handshake/machine.go | 2 + noise.go | 73 --------------- noiseutil/aesgcm.go | 53 +++++++++++ noiseutil/chachapoly.go | 52 +++++++++++ noiseutil/cipher_state.go | 40 ++++++++ noiseutil/cipher_state_test.go | 166 +++++++++++++++++++++++++++++++++ pki.go | 8 +- 8 files changed, 321 insertions(+), 82 deletions(-) delete mode 100644 noise.go create mode 100644 noiseutil/aesgcm.go create mode 100644 noiseutil/chachapoly.go create mode 100644 noiseutil/cipher_state.go create mode 100644 noiseutil/cipher_state_test.go diff --git a/connection_state.go b/connection_state.go index 47e23b5a..0ae2d9be 100644 --- a/connection_state.go +++ b/connection_state.go @@ -7,13 +7,14 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/noiseutil" ) const ReplayWindow = 1024 type ConnectionState struct { - eKey *NebulaCipherState - dKey *NebulaCipherState + eKey noiseutil.CipherState + dKey noiseutil.CipherState myCert cert.Certificate peerCert *cert.CachedCertificate initiator bool @@ -31,8 +32,8 @@ func newConnectionStateFromResult(r *handshake.Result) *ConnectionState { myCert: r.MyCert, initiator: r.Initiator, peerCert: r.RemoteCert, - eKey: NewNebulaCipherState(r.EKey), - dKey: NewNebulaCipherState(r.DKey), + eKey: noiseutil.NewCipherState(r.EKey, r.Cipher), + dKey: noiseutil.NewCipherState(r.DKey, r.Cipher), window: NewBits(ReplayWindow), } ci.messageCounter.Add(r.MessageIndex) diff --git a/handshake/machine.go b/handshake/machine.go index 25ed3a5a..737358dc 100644 --- a/handshake/machine.go +++ b/handshake/machine.go @@ -31,6 +31,7 @@ type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error) type Result struct { EKey *noise.CipherState DKey *noise.CipherState + Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in MyCert cert.Certificate RemoteCert *cert.CachedCertificate RemoteIndex uint32 @@ -105,6 +106,7 @@ func NewMachine( myVersion: version, result: &Result{ Initiator: initiator, + Cipher: cred.cipherSuite, }, }, nil } diff --git a/noise.go b/noise.go deleted file mode 100644 index 0491da17..00000000 --- a/noise.go +++ /dev/null @@ -1,73 +0,0 @@ -package nebula - -import ( - "crypto/cipher" - "encoding/binary" - "errors" - - "github.com/flynn/noise" -) - -type endianness interface { - PutUint64(b []byte, v uint64) -} - -var noiseEndianness endianness = binary.BigEndian - -type NebulaCipherState struct { - c cipher.AEAD -} - -func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { - x := s.Cipher() - return &NebulaCipherState{c: x.(cipher.AEAD)} -} - -// EncryptDanger encrypts and authenticates a given payload. -// -// out is a destination slice to hold the output of the EncryptDanger operation. -// - ad is additional data, which will be authenticated and appended to out, but not encrypted. -// - plaintext is encrypted, authenticated and appended to out. -// - n is a nonce value which must never be re-used with this key. -// - nb is a buffer used for temporary storage in the implementation of this call, which should -// be re-used by callers to minimize garbage collection. -func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { - if s != nil { - // TODO: Is this okay now that we have made messageCounter atomic? - // Alternative may be to split the counter space into ranges - //if n <= s.n { - // return nil, errors.New("CRITICAL: a duplicate counter value was used") - //} - //s.n = n - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - out = s.c.Seal(out, nb, plaintext, ad) - //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) - return out, nil - } else { - return nil, errors.New("no cipher state available to encrypt") - } -} - -func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { - if s != nil { - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - return s.c.Open(out, nb, ciphertext, ad) - } else { - return []byte{}, nil - } -} - -func (s *NebulaCipherState) Overhead() int { - if s != nil { - return s.c.Overhead() - } - return 0 -} diff --git a/noiseutil/aesgcm.go b/noiseutil/aesgcm.go new file mode 100644 index 00000000..dcbd5693 --- /dev/null +++ b/noiseutil/aesgcm.go @@ -0,0 +1,53 @@ +package noiseutil + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/flynn/noise" +) + +// CipherStateAESGCM is the data-plane wrapper for the AES-GCM AEAD cipher. +// AES-GCM uses big-endian nonce encoding per the Noise spec. +type CipherStateAESGCM struct { + c cipher.AEAD +} + +// NewCipherStateAESGCM extracts the underlying AEAD from the post-handshake noise.CipherState. +// The caller is responsible for ensuring the noise cipher is actually AES-GCM, +// otherwise the type assertion still succeeds but the nonce endianness will be wrong on the wire. +func NewCipherStateAESGCM(s *noise.CipherState) *CipherStateAESGCM { + return &CipherStateAESGCM{c: s.Cipher().(cipher.AEAD)} +} + +func (s *CipherStateAESGCM) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return nil, errors.New("no cipher state available to encrypt") + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.BigEndian.PutUint64(nb[4:], n) + return s.c.Seal(out, nb, plaintext, ad), nil +} + +func (s *CipherStateAESGCM) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.BigEndian.PutUint64(nb[4:], n) + return s.c.Open(out, nb, ciphertext, ad) +} + +func (s *CipherStateAESGCM) Overhead() int { + if s == nil { + return 0 + } + return s.c.Overhead() +} diff --git a/noiseutil/chachapoly.go b/noiseutil/chachapoly.go new file mode 100644 index 00000000..31ab3bfe --- /dev/null +++ b/noiseutil/chachapoly.go @@ -0,0 +1,52 @@ +package noiseutil + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/flynn/noise" +) + +// CipherStateChaChaPoly is the data-plane wrapper for the ChaCha20-Poly1305 AEAD cipher. +// ChaCha20-Poly1305 uses little-endian nonce encoding per the Noise spec. +type CipherStateChaChaPoly struct { + c cipher.AEAD +} + +// NewCipherStateChaChaPoly extracts the underlying AEAD from the post-handshake noise.CipherState. +// The caller is responsible for ensuring the noise cipher is actually ChaCha20-Poly1305. +func NewCipherStateChaChaPoly(s *noise.CipherState) *CipherStateChaChaPoly { + return &CipherStateChaChaPoly{c: s.Cipher().(cipher.AEAD)} +} + +func (s *CipherStateChaChaPoly) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return nil, errors.New("no cipher state available to encrypt") + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.LittleEndian.PutUint64(nb[4:], n) + return s.c.Seal(out, nb, plaintext, ad), nil +} + +func (s *CipherStateChaChaPoly) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.LittleEndian.PutUint64(nb[4:], n) + return s.c.Open(out, nb, ciphertext, ad) +} + +func (s *CipherStateChaChaPoly) Overhead() int { + if s == nil { + return 0 + } + return s.c.Overhead() +} diff --git a/noiseutil/cipher_state.go b/noiseutil/cipher_state.go new file mode 100644 index 00000000..bb316385 --- /dev/null +++ b/noiseutil/cipher_state.go @@ -0,0 +1,40 @@ +package noiseutil + +import ( + "fmt" + + "github.com/flynn/noise" +) + +// CipherState is the post-handshake AEAD cipher used for the data plane. +// Each supported cipher has its own concrete implementation in this package with the nonce endianness hardcoded, +// so the encrypt/decrypt fast path avoids interface dispatch on the byte order. +type CipherState interface { + // EncryptDanger encrypts and authenticates a given payload. + // + // out is a destination slice to hold the output of the EncryptDanger operation. + // - ad is additional data, which will be authenticated and appended to out, but not encrypted. + // - plaintext is encrypted, authenticated and appended to out. + // - n is a nonce value which must never be re-used with this key. + // - nb is a scratch buffer used to assemble the nonce. + EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) + + // DecryptDanger authenticates and decrypts a given payload, with the same argument shape as EncryptDanger. + DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) + + // Overhead returns the AEAD tag size, or 0 if the receiver is nil. + Overhead() int +} + +// NewCipherState wraps the post-handshake noise.CipherState in the per-cipher type that matches cipherFunc. +// cipherFunc must be the same cipher used to build the noise CipherSuite that produced s. +func NewCipherState(s *noise.CipherState, cipherFunc noise.CipherFunc) CipherState { + switch cipherFunc.CipherName() { + case CipherAESGCM.CipherName(): + return NewCipherStateAESGCM(s) + case noise.CipherChaChaPoly.CipherName(): + return NewCipherStateChaChaPoly(s) + default: + panic(fmt.Sprintf("noiseutil: unsupported cipher %q", cipherFunc.CipherName())) + } +} diff --git a/noiseutil/cipher_state_test.go b/noiseutil/cipher_state_test.go new file mode 100644 index 00000000..a4df01e9 --- /dev/null +++ b/noiseutil/cipher_state_test.go @@ -0,0 +1,166 @@ +package noiseutil + +import ( + "testing" + + "github.com/flynn/noise" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCipherStateAESGCMRoundtrip(t *testing.T) { + enc, dec := buildCipherStates(t, CipherAESGCM) + roundtrip(t, NewCipherStateAESGCM(enc), NewCipherStateAESGCM(dec)) +} + +func TestCipherStateChaChaPolyRoundtrip(t *testing.T) { + enc, dec := buildCipherStates(t, noise.CipherChaChaPoly) + roundtrip(t, NewCipherStateChaChaPoly(enc), NewCipherStateChaChaPoly(dec)) +} + +func TestNewCipherStateDispatch(t *testing.T) { + encA, _ := buildCipherStates(t, CipherAESGCM) + encC, _ := buildCipherStates(t, noise.CipherChaChaPoly) + + assert.IsType(t, &CipherStateAESGCM{}, NewCipherState(encA, CipherAESGCM)) + assert.IsType(t, &CipherStateChaChaPoly{}, NewCipherState(encC, noise.CipherChaChaPoly)) +} + +func TestNewCipherStateUnsupportedPanics(t *testing.T) { + enc, _ := buildCipherStates(t, CipherAESGCM) + assert.Panics(t, func() { + NewCipherState(enc, fakeCipher{}) + }) +} + +type fakeCipher struct{} + +func (fakeCipher) Cipher(k [32]byte) noise.Cipher { return nil } +func (fakeCipher) CipherName() string { return "Fake" } + +// buildCipherStates runs an in-memory NN handshake with the requested cipher +// to produce a pair of post-handshake CipherStates that share keys. +func buildCipherStates(t *testing.T, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) { + t.Helper() + suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256) + cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN} + cfg.Initiator = true + hsI, err := noise.NewHandshakeState(cfg) + require.NoError(t, err) + cfg.Initiator = false + hsR, err := noise.NewHandshakeState(cfg) + require.NoError(t, err) + + msg, _, _, err := hsI.WriteMessage(nil, nil) + require.NoError(t, err) + _, _, _, err = hsR.ReadMessage(nil, msg) + require.NoError(t, err) + + msg, dR, _, err := hsR.WriteMessage(nil, nil) + require.NoError(t, err) + _, eI, _, err := hsI.ReadMessage(nil, msg) + require.NoError(t, err) + require.NotNil(t, eI) + require.NotNil(t, dR) + + // noise returns (cs1, cs2) where cs1 is the initiator->responder cipher. + return eI, dR +} + +func roundtrip(t *testing.T, enc, dec CipherState) { + t.Helper() + plaintext := []byte("nebula cipher state roundtrip") + ad := []byte("aad") + nb := make([]byte, 12) + + ct, err := enc.EncryptDanger(nil, ad, plaintext, 1, nb) + require.NoError(t, err) + assert.NotEqual(t, plaintext, ct) + + pt, err := dec.DecryptDanger(nil, ad, ct, 1, nb) + require.NoError(t, err) + assert.Equal(t, plaintext, pt) + + // Wrong nonce must fail authentication. + _, err = dec.DecryptDanger(nil, ad, ct, 2, nb) + require.Error(t, err) + + assert.Equal(t, enc.Overhead(), dec.Overhead()) + assert.Equal(t, 16, enc.Overhead()) +} + +func BenchmarkCipherStateEncryptAESGCM(b *testing.B) { + enc, _ := buildCipherStatesB(b, CipherAESGCM) + benchEncryptCipherState(b, NewCipherState(enc, CipherAESGCM)) +} + +func BenchmarkCipherStateEncryptChaChaPoly(b *testing.B) { + enc, _ := buildCipherStatesB(b, noise.CipherChaChaPoly) + benchEncryptCipherState(b, NewCipherState(enc, noise.CipherChaChaPoly)) +} + +func benchEncryptCipherState(b *testing.B, cs CipherState) { + plaintext := make([]byte, 1280) + ad := make([]byte, 16) + nb := make([]byte, 12) + out := make([]byte, 0, len(plaintext)+cs.Overhead()) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var err error + out, err = cs.EncryptDanger(out[:0], ad, plaintext, uint64(i+1), nb) + if err != nil { + b.Fatal(err) + } + } +} + +func buildCipherStatesB(b *testing.B, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) { + b.Helper() + suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256) + cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN} + cfg.Initiator = true + hsI, err := noise.NewHandshakeState(cfg) + if err != nil { + b.Fatal(err) + } + cfg.Initiator = false + hsR, err := noise.NewHandshakeState(cfg) + if err != nil { + b.Fatal(err) + } + msg, _, _, err := hsI.WriteMessage(nil, nil) + if err != nil { + b.Fatal(err) + } + if _, _, _, err := hsR.ReadMessage(nil, msg); err != nil { + b.Fatal(err) + } + msg, dR, _, err := hsR.WriteMessage(nil, nil) + if err != nil { + b.Fatal(err) + } + _, eI, _, err := hsI.ReadMessage(nil, msg) + if err != nil { + b.Fatal(err) + } + return eI, dR +} + +func TestCipherStateNilSafety(t *testing.T) { + var aes *CipherStateAESGCM + _, err := aes.EncryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.Error(t, err) + out, err := aes.DecryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.NoError(t, err) + assert.Empty(t, out) + assert.Equal(t, 0, aes.Overhead()) + + var cc *CipherStateChaChaPoly + _, err = cc.EncryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.Error(t, err) + out, err = cc.DecryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.NoError(t, err) + assert.Empty(t, out) + assert.Equal(t, 0, cc.Overhead()) +} diff --git a/pki.go b/pki.go index acc80486..1bef5106 100644 --- a/pki.go +++ b/pki.go @@ -99,12 +99,10 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { var currentState *CertState if initial { cipher = c.GetString("cipher", "aes") - //TODO: this sucks and we should make it not a global switch cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian + case "aes", "chachapoly": + // Each post-handshake CipherState in noiseutil hardcodes its own + // nonce endianness now, so there's nothing to set up here. default: return util.NewContextualError( "unknown cipher",