From cf3b7ec2fafa2ad9a582ff8fbd9da22fe3e6e9d3 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 14 Apr 2021 17:04:17 -0500 Subject: [PATCH] PSK Support --- connection_state.go | 13 ++-- examples/config.yml | 29 +++++++- handshake_ix.go | 58 ++++++++++----- inside.go | 24 +++--- interface.go | 25 ++++++- main.go | 10 ++- noise.go | 1 - psk.go | 178 ++++++++++++++++++++++++++++++++++++++++++++ psk_test.go | 78 +++++++++++++++++++ 9 files changed, 372 insertions(+), 44 deletions(-) create mode 100644 psk.go create mode 100644 psk_test.go diff --git a/connection_state.go b/connection_state.go index c28cc42..649669c 100644 --- a/connection_state.go +++ b/connection_state.go @@ -27,7 +27,7 @@ type ConnectionState struct { ready bool } -func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, p []byte) (*ConnectionState, error) { cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256) if f.cipher == "chachapoly" { cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) @@ -43,14 +43,15 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern hs, err := noise.NewHandshakeState(noise.Config{ CipherSuite: cs, Random: rand.Reader, - Pattern: pattern, + Pattern: noise.HandshakeIX, Initiator: initiator, StaticKeypair: static, - PresharedKey: psk, - PresharedKeyPlacement: pskStage, + PresharedKey: p, + PresharedKeyPlacement: 0, }) + if err != nil { - return nil + return nil, err } // The queue and ready params prevent a counter race that would happen when @@ -63,7 +64,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern certState: curCertState, } - return ci + return ci, nil } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { diff --git a/examples/config.yml b/examples/config.yml index 87b7954..77f6ea1 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -215,17 +215,42 @@ logging: # e.g.: `lighthouse.rx.HostQuery` #lighthouse_metrics: false -# Handshake Manager Settings -#handshakes: +# Handshake Manger Settings +handshakes: # Handshakes are sent to all known addresses at each interval with a linear backoff, # Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out #try_interval: 100ms #retries: 20 + # trigger_buffer is the size of the buffer channel for quickly sending handshakes # after receiving the response for lighthouse queries #trigger_buffer: 64 + # pki can be used to mask the contents of handshakes and makes handshaking with unintended recipients more difficult + psk: + # mode defines the how pre shared keys can be used in a handshake + # `none` (the default) does not send or receive using a psk. Ideally `enforced` is used. + # `transitional` can receive handshakes using a psk that we know about, but we will not send any handshakes using a psk. + # This is helpful for transitioning to `enforced` and should be changed to `enforced` as soon as possible. + # Move every node in your mesh to `transitional` then you can move every node in your mesh to `enforced` without having to stop the world + # This assumes `keys` is the same on every node in your mesh + # `enforced` enforces the use of a psk for all tunnels. Any node not also using `enforced` or `transitional` will not be able to handshake with us + #mode: none + + # In `transitional` and `enforced` modes, the keys provided here are sent through hkdf with the intended recipients + # ip used in the info section. This helps guard against handshaking with the wrong host if your static_host_map or + # lighthouse(s) has incorrect information. + # + # Setting keys if mode is `none` has no effect. + # + # Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on + # incoming handshakes. This is to allow for psk rotation. + #keys: + # - shared secret string, this one is used in all outbound handshakes + # - this is a fallback key, received handshakes can use this + # - another fallback, received handshakes can use this one too + # - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire # Nebula security group configuration firewall: diff --git a/handshake_ix.go b/handshake_ix.go index a0defc6..2222d13 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -4,7 +4,6 @@ import ( "sync/atomic" "time" - "github.com/flynn/noise" "github.com/golang/protobuf/proto" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" @@ -71,28 +70,51 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { } func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) { - ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) - // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(f.l, 1) - - msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") - return - } + var ( + err error + ci *ConnectionState + msg []byte + ) hs := &NebulaHandshake{} - err = proto.Unmarshal(msg, hs) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ - if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + + // Handle multiple possible psk options, ensure the protobuf comes out clean too + for _, p := range f.psk.Cache { + //TODO: benchmark generation time of makePsk + ci, err = f.newConnectionState(f.l, false, p) + if err != nil { + f.l.WithError(err).WithField("udpAddr", addr).Error("Failed to get a new connection state") + continue + } + + msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:]) + if err != nil { + // Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one + continue + } + + // Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf + // comes out clean as well + err = proto.Unmarshal(msg, hs) + if err == nil { + // There was no error, we can continue with this handshake + break + } + + // The unmarshal failed, try the next psk if we have one + } + + // We finished with an error, log it and get out + if err != nil { + // We aren't logging the error here because we can't be sure of the failure when using psk + f.l.WithField("udpAddr", addr).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Was unable to decrypt the handshake") return } + // Mark packet 1 as seen so it doesn't show up as missed + ci.window.Update(f.l, 1) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). diff --git a/inside.go b/inside.go index 8a7c990..0e62b78 100644 --- a/inside.go +++ b/inside.go @@ -3,7 +3,6 @@ package nebula import ( "sync/atomic" - "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -79,7 +78,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { } hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) - //if err != nil || hostinfo.ConnectionState == nil { if err != nil { hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { @@ -102,21 +100,27 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { return hostinfo } + // Create a connection state if we don't have one yet if ci == nil { - // if we don't have a connection state, then send a handshake initiation - ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) - // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. - //ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0) + // Generate a PSK based on our config, this may be nil + p, err := f.psk.MakeFor(vpnIp) + if err != nil { + //TODO: This isn't fatal specifically but it's pretty bad + f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a PSK KDF") + return hostinfo + } + + ci, err = f.newConnectionState(f.l, true, p) + if err != nil { + f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a connection state") + return hostinfo + } hostinfo.ConnectionState = ci - } else if ci.eKey == nil { - // if we don't have any state at all, create it } // If we have already created the handshake packet, we don't want to call the function at all. if !hostinfo.HandshakeReady { ixHandshakeStage0(f, vpnIp, hostinfo) - // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. - //xx_handshakeStage0(f, ip, hostinfo) // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now diff --git a/interface.go b/interface.go index c95a354..a75068b 100644 --- a/interface.go +++ b/interface.go @@ -7,7 +7,9 @@ import ( "net" "os" "runtime" + "sync/atomic" "time" + "unsafe" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" @@ -48,6 +50,7 @@ type InterfaceConfig struct { version string caPool *cert.NebulaCAPool disconnectInvalid bool + psk *Psk ConntrackCacheTimeout time.Duration l *logrus.Logger @@ -78,9 +81,9 @@ type Interface struct { version string conntrackCacheTimeout time.Duration - - writers []*udp.Conn - readers []io.ReadWriteCloser + psk *Psk + writers []*udp.Conn + readers []io.ReadWriteCloser metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -104,6 +107,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { } myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP) + ifce := &Interface{ hostMap: c.HostMap, outside: c.Outside, @@ -124,6 +128,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { readers: make([]io.ReadWriteCloser, c.routines), caPool: c.caPool, disconnectInvalid: c.disconnectInvalid, + psk: c.psk, myVpnIp: myVpnIp, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -234,6 +239,7 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { for _, udpConn := range f.writers { c.RegisterReloadCallback(udpConn.ReloadConfig) } + c.RegisterReloadCallback(f.reloadPSKs) } func (f *Interface) reloadCA(c *config.C) { @@ -308,6 +314,19 @@ func (f *Interface) reloadFirewall(c *config.C) { Info("New firewall has been installed") } +func (f *Interface) reloadPSKs(c *config.C) { + psk, err := NewPskFromConfig(c, f.myVpnIp) + if err != nil { + f.l.WithError(err).Error("Error while reloading PSKs") + return + } + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&f.psk)), unsafe.Pointer(psk)) + + f.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.Cache)). + Info("New psks are in use") +} + func (f *Interface) emitStats(ctx context.Context, i time.Duration) { ticker := time.NewTicker(i) defer ticker.Stop() diff --git a/main.go b/main.go index 91418e1..b15ea99 100644 --- a/main.go +++ b/main.go @@ -95,6 +95,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } + psk, err := NewPskFromConfig(c, iputil.Ip2VpnIp(tunCidr.IP)) + if err != nil { + return nil, NewContextualError("Failed to create psk", nil, err) + } + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // All non system modifying configuration consumption should live above this line // tun config, listeners, anything modifying the computer should be below @@ -356,10 +361,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - //TODO: These will be reused for psk - //handshakeMACKey := config.GetString("handshake_mac.key", "") - //handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{}) - serveDns := false if c.GetBool("lighthouse.serve_dns", false) { if c.GetBool("lighthouse.am_lighthouse", false) { @@ -390,6 +391,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg version: buildVersion, caPool: caPool, disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), + psk: psk, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, diff --git a/noise.go b/noise.go index 543bb52..55a5e38 100644 --- a/noise.go +++ b/noise.go @@ -22,7 +22,6 @@ type NebulaCipherState struct { func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { return &NebulaCipherState{c: s.Cipher()} - } func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { diff --git a/psk.go b/psk.go new file mode 100644 index 0000000..cd61236 --- /dev/null +++ b/psk.go @@ -0,0 +1,178 @@ +package nebula + +import ( + "crypto/sha256" + "errors" + "fmt" + "io" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" + "golang.org/x/crypto/hkdf" +) + +var ErrNotAPskMode = errors.New("not a psk mode") +var ErrKeyTooShort = errors.New("key is too short") +var ErrNotEnoughPskKeys = errors.New("at least 1 key is required") + +// The minimum length that we accept for a user defined psk, the choice is arbitrary +const MinPskLength = 8 + +type PskMode int + +func (p PskMode) String() string { + switch p { + case PskNone: + return "none" + case PskTransitional: + return "transitional" + case PskEnforced: + return "enforced" + } + + return "unknown" +} + +func NewPskMode(m string) (PskMode, error) { + switch m { + case "none": + return PskNone, nil + case "transitional": + return PskTransitional, nil + case "enforced": + return PskEnforced, nil + } + return PskNone, ErrNotAPskMode +} + +const ( + PskNone PskMode = 0 + PskTransitional PskMode = 1 + PskEnforced PskMode = 2 +) + +type Psk struct { + // pskMode sets how psk works, ignored, allowed for incoming, or enforced for all + mode PskMode + + // Cache holds all pre-computed psk hkdfs + // Handshakes iterate this directly + Cache [][]byte + + // The key has already been extracted and is ready to be expanded for use + // MakeFor does the final expand step mixing in the intended recipients vpn ip + key []byte +} + +// NewPskFromConfig is a helper for initial boot and config reloading. +func NewPskFromConfig(c *config.C, myVpnIp iputil.VpnIp) (*Psk, error) { + sMode := c.GetString("handshakes.psk.mode", "none") + mode, err := NewPskMode(sMode) + if err != nil { + return nil, NewContextualError("Could not parse handshakes.psk.mode", m{"mode": mode}, err) + } + + return NewPsk( + mode, + c.GetStringSlice("handshakes.psk.keys", nil), + myVpnIp, + ) +} + +// NewPsk creates a new Psk object and handles the caching of all accepted keys and preparation of the primary key +func NewPsk(mode PskMode, keys []string, myVpnIp iputil.VpnIp) (*Psk, error) { + psk := &Psk{ + mode: mode, + } + + err := psk.preparePrimaryKey(keys) + if err != nil { + return nil, err + } + + err = psk.cachePsks(myVpnIp, keys) + if err != nil { + return nil, err + } + + return psk, nil +} + +// MakeFor if we are in enforced mode, the final hkdf expand stage is done on the pre extracted primary key, +// mixing in the intended recipients vpn ip and the result is returned. +// If we are transitional or not using psks, an empty byte slice is returned +func (p *Psk) MakeFor(vpnIp iputil.VpnIp) ([]byte, error) { + if p.mode != PskEnforced { + return []byte{}, nil + } + + hmacKey := make([]byte, sha256.Size) + _, err := io.ReadFull(hkdf.Expand(sha256.New, p.key, vpnIp.ToIP()), hmacKey) + if err != nil { + return nil, err + } + + return hmacKey, nil +} + +// cachePsks generates all psks we accept and caches them to speed up handshaking +func (p *Psk) cachePsks(myVpnIp iputil.VpnIp, keys []string) error { + // If PskNone is set then we are using the nil byte array for a psk, we can return + if p.mode == PskNone { + p.Cache = [][]byte{nil} + return nil + } + + if len(keys) < 1 { + return ErrNotEnoughPskKeys + } + + p.Cache = [][]byte{} + + if p.mode == PskTransitional { + // We are transitional, we accept empty psks + p.Cache = append(p.Cache, nil) + } + + // We are either PskAuto or PskTransitional, build all possibilities + for i, rk := range keys { + k, err := sha256KdfFromString(rk, myVpnIp) + if err != nil { + return fmt.Errorf("failed to generate key for position %v: %w", i, err) + } + + p.Cache = append(p.Cache, k) + } + + return nil +} + +// preparePrimaryKey if we are in enforced mode, will do an hkdf extract on the first key to benefit +// outgoing handshake performance, MakeFor does the final expand step +func (p *Psk) preparePrimaryKey(keys []string) error { + if p.mode != PskEnforced { + // If we aren't enforcing then there is nothing to prepare + return nil + } + + if len(keys) < 1 { + return ErrNotEnoughPskKeys + } + + p.key = hkdf.Extract(sha256.New, []byte(keys[0]), nil) + return nil +} + +// sha256KdfFromString generates a full hkdf +func sha256KdfFromString(secret string, vpnIp iputil.VpnIp) ([]byte, error) { + if len(secret) < MinPskLength { + return nil, ErrKeyTooShort + } + + hmacKey := make([]byte, sha256.Size) + _, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, vpnIp.ToIP()), hmacKey) + if err != nil { + return nil, err + } + return hmacKey, nil +} diff --git a/psk_test.go b/psk_test.go new file mode 100644 index 0000000..88b7f27 --- /dev/null +++ b/psk_test.go @@ -0,0 +1,78 @@ +package nebula + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewPsk(t *testing.T) { + t.Run("mode none", func(t *testing.T) { + p, err := NewPsk(PskNone, nil, 1) + assert.NoError(t, err) + assert.Equal(t, PskNone, p.mode) + assert.Empty(t, p.key) + + assert.Len(t, p.Cache, 1) + assert.Nil(t, p.Cache[0]) + + b, err := p.MakeFor(0) + assert.Equal(t, []byte{}, b) + }) + + t.Run("mode transitional", func(t *testing.T) { + p, err := NewPsk(PskTransitional, nil, 1) + assert.Error(t, ErrNotEnoughPskKeys, err) + + p, err = NewPsk(PskTransitional, []string{"1234567"}, 1) + assert.Error(t, ErrKeyTooShort) + + p, err = NewPsk(PskTransitional, []string{"hi there friends"}, 1) + assert.NoError(t, err) + assert.Equal(t, PskTransitional, p.mode) + assert.Empty(t, p.key) + + assert.Len(t, p.Cache, 2) + assert.Nil(t, p.Cache[0]) + + expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17} + assert.Equal(t, expectedCache, p.Cache[1]) + + b, err := p.MakeFor(0) + assert.Equal(t, []byte{}, b) + }) + + t.Run("mode enforced", func(t *testing.T) { + p, err := NewPsk(PskEnforced, nil, 1) + assert.Error(t, ErrNotEnoughPskKeys, err) + + p, err = NewPsk(PskEnforced, []string{"hi there friends"}, 1) + assert.NoError(t, err) + assert.Equal(t, PskEnforced, p.mode) + + expectedKey := []byte{156, 103, 171, 88, 121, 92, 138, 240, 170, 240, 76, 108, 154, 66, 107, 14, 226, 148, 177, 0, 40, 28, 220, 136, 68, 53, 63, 183, 213, 9, 192, 218} + assert.Equal(t, expectedKey, p.key) + + assert.Len(t, p.Cache, 1) + expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17} + assert.Equal(t, expectedCache, p.Cache[0]) + + expectedPsk := []byte{0xd9, 0x16, 0xa3, 0x66, 0x6a, 0x20, 0x26, 0xcf, 0x5d, 0x93, 0xad, 0xa3, 0x88, 0x2d, 0x57, 0xac, 0x9b, 0xc3, 0x5a, 0xb7, 0x8f, 0x6, 0x71, 0xc4, 0x3e, 0x5, 0x9e, 0xbc, 0x4e, 0xc8, 0x24, 0x17} + b, err := p.MakeFor(0) + assert.Equal(t, expectedPsk, b) + + // Make sure different vpn ips generate different psks + expectedPsk = []byte{0x92, 0x78, 0x87, 0x1f, 0x9e, 0x66, 0x2d, 0xbd, 0x80, 0xbe, 0x25, 0x65, 0x3a, 0xfe, 0x6, 0xa6, 0x5b, 0xd1, 0x94, 0x83, 0x1b, 0xc1, 0x18, 0x19, 0xaa, 0x41, 0x82, 0xbd, 0x7, 0xb3, 0xff, 0x11} + b, err = p.MakeFor(1) + assert.Equal(t, expectedPsk, b) + }) +} + +func BenchmarkPsk_MakeFor(b *testing.B) { + p, err := NewPsk(PskEnforced, []string{"hi there friends"}, 1) + assert.NoError(b, err) + + for n := 0; n < b.N; n++ { + p.MakeFor(99) + } +}