PSK Support

This commit is contained in:
Nate Brown 2021-04-14 17:04:17 -05:00
parent a22c134bf5
commit cf3b7ec2fa
9 changed files with 372 additions and 44 deletions

View File

@ -27,7 +27,7 @@ type ConnectionState struct {
ready bool
}
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, p []byte) (*ConnectionState, error) {
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
@ -43,14 +43,15 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs,
Random: rand.Reader,
Pattern: pattern,
Pattern: noise.HandshakeIX,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: pskStage,
PresharedKey: p,
PresharedKeyPlacement: 0,
})
if err != nil {
return nil
return nil, err
}
// The queue and ready params prevent a counter race that would happen when
@ -63,7 +64,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
certState: curCertState,
}
return ci
return ci, nil
}
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {

View File

@ -215,17 +215,42 @@ logging:
# e.g.: `lighthouse.rx.HostQuery`
#lighthouse_metrics: false
# Handshake Manager Settings
#handshakes:
# Handshake Manager Settings
handshakes:
# Handshakes are sent to all known addresses at each interval with a linear backoff,
# Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
# A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
#try_interval: 100ms
#retries: 20
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# psk can be used to mask the contents of handshakes and makes handshaking with unintended recipients more difficult
psk:
# mode defines how pre-shared keys can be used in a handshake
# `none` (the default) does not send or receive using a psk. Ideally `enforced` is used.
# `transitional` can receive handshakes using a psk that we know about, but we will not send any handshakes using a psk.
# This is helpful for transitioning to `enforced` and should be changed to `enforced` as soon as possible.
# Move every node in your mesh to `transitional` then you can move every node in your mesh to `enforced` without having to stop the world
# This assumes `keys` is the same on every node in your mesh
# `enforced` enforces the use of a psk for all tunnels. Any node not also using `enforced` or `transitional` will not be able to handshake with us
#mode: none
# In `transitional` and `enforced` modes, the keys provided here are sent through hkdf with the intended recipient's
# ip used in the info section. This helps guard against handshaking with the wrong host if your static_host_map or
# lighthouse(s) has incorrect information.
#
# Setting keys if mode is `none` has no effect.
#
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
# incoming handshakes. This is to allow for psk rotation.
#keys:
# - shared secret string, this one is used in all outbound handshakes
# - this is a fallback key, received handshakes can use this
# - another fallback, received handshakes can use this one too
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire
# Nebula security group configuration
firewall:

View File

@ -4,7 +4,6 @@ import (
"sync/atomic"
"time"
"github.com/flynn/noise"
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
@ -71,28 +70,51 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
}
func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return
}
var (
err error
ci *ConnectionState
msg []byte
)
hs := &NebulaHandshake{}
// Handle multiple possible psk options, ensure the protobuf comes out clean too
for _, p := range f.psk.Cache {
//TODO: benchmark generation time of makePsk
ci, err = f.newConnectionState(f.l, false, p)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).Error("Failed to get a new connection state")
continue
}
msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
continue
}
// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
// comes out clean as well
err = proto.Unmarshal(msg, hs)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
if err == nil {
// There was no error, we can continue with this handshake
break
}
// The unmarshal failed, try the next psk if we have one
}
// We finished with an error, log it and get out
if err != nil {
// We aren't logging the error here because we can't be sure of the failure when using psk
f.l.WithField("udpAddr", addr).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Was unable to decrypt the handshake")
return
}
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).

View File

@ -3,7 +3,6 @@ package nebula
import (
"sync/atomic"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@ -79,7 +78,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
}
hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
//if err != nil || hostinfo.ConnectionState == nil {
if err != nil {
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
@ -102,21 +100,27 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
return hostinfo
}
// Create a connection state if we don't have one yet
if ci == nil {
// if we don't have a connection state, then send a handshake initiation
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
// Generate a PSK based on our config, this may be nil
p, err := f.psk.MakeFor(vpnIp)
if err != nil {
//TODO: This isn't fatal specifically but it's pretty bad
f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a PSK KDF")
return hostinfo
}
ci, err = f.newConnectionState(f.l, true, p)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a connection state")
return hostinfo
}
hostinfo.ConnectionState = ci
} else if ci.eKey == nil {
// if we don't have any state at all, create it
}
// If we have already created the handshake packet, we don't want to call the function at all.
if !hostinfo.HandshakeReady {
ixHandshakeStage0(f, vpnIp, hostinfo)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//xx_handshakeStage0(f, ip, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now

View File

@ -7,7 +7,9 @@ import (
"net"
"os"
"runtime"
"sync/atomic"
"time"
"unsafe"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
@ -48,6 +50,7 @@ type InterfaceConfig struct {
version string
caPool *cert.NebulaCAPool
disconnectInvalid bool
psk *Psk
ConntrackCacheTimeout time.Duration
l *logrus.Logger
@ -78,7 +81,7 @@ type Interface struct {
version string
conntrackCacheTimeout time.Duration
psk *Psk
writers []*udp.Conn
readers []io.ReadWriteCloser
@ -104,6 +107,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
}
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
ifce := &Interface{
hostMap: c.HostMap,
outside: c.Outside,
@ -124,6 +128,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool,
disconnectInvalid: c.disconnectInvalid,
psk: c.psk,
myVpnIp: myVpnIp,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -234,6 +239,7 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.ReloadConfig)
}
c.RegisterReloadCallback(f.reloadPSKs)
}
func (f *Interface) reloadCA(c *config.C) {
@ -308,6 +314,19 @@ func (f *Interface) reloadFirewall(c *config.C) {
Info("New firewall has been installed")
}
// reloadPSKs is a config reload callback that rebuilds the psk state from the reloaded config and swaps it in.
// If the new config can not be turned into a Psk the error is logged and the psks already in use are left untouched.
func (f *Interface) reloadPSKs(c *config.C) {
	next, err := NewPskFromConfig(c, f.myVpnIp)
	if err != nil {
		f.l.WithError(err).Error("Error while reloading PSKs")
		return
	}

	// f.psk is a plain *Psk field, view it as an unsafe.Pointer slot so the replacement is written with an atomic store
	slot := (*unsafe.Pointer)(unsafe.Pointer(&f.psk))
	atomic.StorePointer(slot, unsafe.Pointer(next))

	entry := f.l.WithField("pskMode", next.mode)
	entry = entry.WithField("keysLen", len(next.Cache))
	entry.Info("New psks are in use")
}
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i)
defer ticker.Stop()

10
main.go
View File

@ -95,6 +95,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
psk, err := NewPskFromConfig(c, iputil.Ip2VpnIp(tunCidr.IP))
if err != nil {
return nil, NewContextualError("Failed to create psk", nil, err)
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// All non system modifying configuration consumption should live above this line
// tun config, listeners, anything modifying the computer should be below
@ -356,10 +361,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk
//handshakeMACKey := config.GetString("handshake_mac.key", "")
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
serveDns := false
if c.GetBool("lighthouse.serve_dns", false) {
if c.GetBool("lighthouse.am_lighthouse", false) {
@ -390,6 +391,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
version: buildVersion,
caPool: caPool,
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
psk: psk,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,

View File

@ -22,7 +22,6 @@ type NebulaCipherState struct {
// NewNebulaCipherState builds a NebulaCipherState around the cipher obtained from the provided noise cipher state.
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
	state := new(NebulaCipherState)
	state.c = s.Cipher()
	return state
}
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {

178
psk.go Normal file
View File

@ -0,0 +1,178 @@
package nebula
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"golang.org/x/crypto/hkdf"
)
// Sentinel errors returned while parsing and preparing the psk configuration.
var (
	// ErrNotAPskMode is returned by NewPskMode when the supplied string is not a recognized mode.
	ErrNotAPskMode = errors.New("not a psk mode")

	// ErrKeyTooShort is returned when a configured key has fewer than MinPskLength bytes.
	ErrKeyTooShort = errors.New("key is too short")

	// ErrNotEnoughPskKeys is returned when a psk mode that needs keys is used without providing any.
	ErrNotEnoughPskKeys = errors.New("at least 1 key is required")
)

// MinPskLength is the minimum length that we accept for a user defined psk, the choice is arbitrary.
const MinPskLength = 8
// PskMode describes how pre-shared keys take part in handshakes.
type PskMode int

// String returns the config file spelling of the mode, or "unknown" for any value that is not a defined mode.
func (p PskMode) String() string {
	// Indexed by mode value so the lookup stays tied to the constants
	names := [...]string{
		PskNone:         "none",
		PskTransitional: "transitional",
		PskEnforced:     "enforced",
	}

	if p < 0 || int(p) >= len(names) {
		return "unknown"
	}
	return names[p]
}
// NewPskMode converts the config file spelling of a mode ("none", "transitional" or "enforced") into a PskMode.
// Any other string yields PskNone together with ErrNotAPskMode.
func NewPskMode(m string) (PskMode, error) {
	// Compare against the canonical name of every defined mode
	for _, candidate := range []PskMode{PskNone, PskTransitional, PskEnforced} {
		if candidate.String() == m {
			return candidate, nil
		}
	}
	return PskNone, ErrNotAPskMode
}
// The supported psk modes, numbered 0 through 2 in declaration order.
const (
	// PskNone is spelled "none" in the config
	PskNone PskMode = iota
	// PskTransitional is spelled "transitional" in the config
	PskTransitional
	// PskEnforced is spelled "enforced" in the config
	PskEnforced
)
// Psk holds the pre-shared key state used while handshaking: the configured mode, the cache of psks that are
// accepted on incoming handshakes and the pre-extracted primary key that outgoing psks are expanded from.
type Psk struct {
	// mode sets how psks are used: ignored, accepted for incoming handshakes only, or enforced for all handshakes
	mode PskMode

	// Cache holds all pre-computed psk hkdfs, derived with our own vpn ip mixed in.
	// Handshakes iterate this directly.
	// In PskNone mode it holds a single nil entry, in PskTransitional mode a nil entry precedes the derived keys.
	Cache [][]byte

	// The key has already been extracted and is ready to be expanded for use.
	// MakeFor does the final expand step mixing in the intended recipient's vpn ip.
	// It is only populated in PskEnforced mode and comes from the first configured key.
	key []byte
}
// NewPskFromConfig is a helper for initial boot and config reloading.
// It reads `handshakes.psk.mode` (defaulting to "none") and `handshakes.psk.keys` from the config and builds a Psk
// for myVpnIp from them. An unrecognized mode is reported as a contextual error carrying the offending string.
func NewPskFromConfig(c *config.C, myVpnIp iputil.VpnIp) (*Psk, error) {
	sMode := c.GetString("handshakes.psk.mode", "none")
	mode, err := NewPskMode(sMode)
	if err != nil {
		// Report the raw configured value, the parsed mode is always PskNone when parsing fails
		return nil, NewContextualError("Could not parse handshakes.psk.mode", m{"mode": sMode}, err)
	}

	return NewPsk(
		mode,
		c.GetStringSlice("handshakes.psk.keys", nil),
		myVpnIp,
	)
}
// NewPsk builds a Psk for the given mode. The primary key is prepared from keys first, then every accepted psk is
// derived for myVpnIp and cached. The first failure from either step is returned along with a nil Psk.
func NewPsk(mode PskMode, keys []string, myVpnIp iputil.VpnIp) (*Psk, error) {
	created := &Psk{mode: mode}

	if err := created.preparePrimaryKey(keys); err != nil {
		return nil, err
	}

	if err := created.cachePsks(myVpnIp, keys); err != nil {
		return nil, err
	}

	return created, nil
}
// MakeFor returns the psk to use for an outgoing handshake to vpnIp.
// In enforced mode the pre-extracted primary key goes through the hkdf expand step with the intended recipient's
// vpn ip as the info, producing a sha256.Size byte psk.
// In every other mode an empty, non nil, byte slice is returned.
func (p *Psk) MakeFor(vpnIp iputil.VpnIp) ([]byte, error) {
	if p.mode != PskEnforced {
		return []byte{}, nil
	}

	expander := hkdf.Expand(sha256.New, p.key, vpnIp.ToIP())
	out := make([]byte, sha256.Size)
	if _, err := io.ReadFull(expander, out); err != nil {
		return nil, err
	}

	return out, nil
}
// cachePsks derives every psk we accept on incoming handshakes and stores them in Cache to speed up handshaking.
// PskNone caches a single nil psk. The other modes need at least one key and cache one derived psk per key, in the
// order given, with PskTransitional placing a nil psk ahead of them.
func (p *Psk) cachePsks(myVpnIp iputil.VpnIp, keys []string) error {
	switch {
	case p.mode == PskNone:
		// Not using psks at all, the nil psk is the only thing we accept
		p.Cache = [][]byte{nil}
		return nil

	case len(keys) < 1:
		return ErrNotEnoughPskKeys
	}

	p.Cache = [][]byte{}
	if p.mode == PskTransitional {
		// Transitional nodes still accept handshakes that do not use a psk
		p.Cache = append(p.Cache, nil)
	}

	// We are either PskEnforced or PskTransitional, build all possibilities
	for pos := range keys {
		derived, err := sha256KdfFromString(keys[pos], myVpnIp)
		if err != nil {
			return fmt.Errorf("failed to generate key for position %v: %w", pos, err)
		}
		p.Cache = append(p.Cache, derived)
	}

	return nil
}
// preparePrimaryKey runs the hkdf extract step on the first key when we are in enforced mode, storing the result
// so MakeFor only has to do the final expand step for each outgoing handshake.
// Outside of enforced mode this is a no-op. In enforced mode an empty key list yields ErrNotEnoughPskKeys.
func (p *Psk) preparePrimaryKey(keys []string) error {
	if p.mode == PskEnforced {
		if len(keys) == 0 {
			return ErrNotEnoughPskKeys
		}

		primary := []byte(keys[0])
		p.key = hkdf.Extract(sha256.New, primary, nil)
	}

	// If we aren't enforcing then there is nothing to prepare
	return nil
}
// sha256KdfFromString runs a full sha256 hkdf over secret, using no salt and the vpn ip as the info, and returns
// sha256.Size bytes of output. Secrets shorter than MinPskLength are rejected with ErrKeyTooShort.
func sha256KdfFromString(secret string, vpnIp iputil.VpnIp) ([]byte, error) {
	if len(secret) < MinPskLength {
		return nil, ErrKeyTooShort
	}

	kdf := hkdf.New(sha256.New, []byte(secret), nil, vpnIp.ToIP())
	out := make([]byte, sha256.Size)
	if _, err := io.ReadFull(kdf, out); err != nil {
		return nil, err
	}

	return out, nil
}

78
psk_test.go Normal file
View File

@ -0,0 +1,78 @@
package nebula
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestNewPsk covers Psk construction and outgoing psk generation for every mode.
// The failure cases assert on the returned err itself; passing a sentinel as the error under test to assert.Error
// can never fail because the sentinel is always non nil.
func TestNewPsk(t *testing.T) {
	t.Run("mode none", func(t *testing.T) {
		p, err := NewPsk(PskNone, nil, 1)
		assert.NoError(t, err)
		assert.Equal(t, PskNone, p.mode)
		assert.Empty(t, p.key)

		assert.Len(t, p.Cache, 1)
		assert.Nil(t, p.Cache[0])

		b, err := p.MakeFor(0)
		assert.NoError(t, err)
		assert.Equal(t, []byte{}, b)
	})

	t.Run("mode transitional", func(t *testing.T) {
		p, err := NewPsk(PskTransitional, nil, 1)
		assert.Equal(t, ErrNotEnoughPskKeys, err)
		assert.Nil(t, p)

		// The short key error is wrapped with the position of the offending key
		p, err = NewPsk(PskTransitional, []string{"1234567"}, 1)
		assert.EqualError(t, err, "failed to generate key for position 0: key is too short")
		assert.Nil(t, p)

		p, err = NewPsk(PskTransitional, []string{"hi there friends"}, 1)
		assert.NoError(t, err)
		assert.Equal(t, PskTransitional, p.mode)
		assert.Empty(t, p.key)

		assert.Len(t, p.Cache, 2)
		assert.Nil(t, p.Cache[0])

		expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17}
		assert.Equal(t, expectedCache, p.Cache[1])

		b, err := p.MakeFor(0)
		assert.NoError(t, err)
		assert.Equal(t, []byte{}, b)
	})

	t.Run("mode enforced", func(t *testing.T) {
		p, err := NewPsk(PskEnforced, nil, 1)
		assert.Equal(t, ErrNotEnoughPskKeys, err)
		assert.Nil(t, p)

		// Enforced mode must reject short keys as well
		p, err = NewPsk(PskEnforced, []string{"1234567"}, 1)
		assert.EqualError(t, err, "failed to generate key for position 0: key is too short")
		assert.Nil(t, p)

		p, err = NewPsk(PskEnforced, []string{"hi there friends"}, 1)
		assert.NoError(t, err)
		assert.Equal(t, PskEnforced, p.mode)

		expectedKey := []byte{156, 103, 171, 88, 121, 92, 138, 240, 170, 240, 76, 108, 154, 66, 107, 14, 226, 148, 177, 0, 40, 28, 220, 136, 68, 53, 63, 183, 213, 9, 192, 218}
		assert.Equal(t, expectedKey, p.key)

		assert.Len(t, p.Cache, 1)
		expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17}
		assert.Equal(t, expectedCache, p.Cache[0])

		expectedPsk := []byte{0xd9, 0x16, 0xa3, 0x66, 0x6a, 0x20, 0x26, 0xcf, 0x5d, 0x93, 0xad, 0xa3, 0x88, 0x2d, 0x57, 0xac, 0x9b, 0xc3, 0x5a, 0xb7, 0x8f, 0x6, 0x71, 0xc4, 0x3e, 0x5, 0x9e, 0xbc, 0x4e, 0xc8, 0x24, 0x17}
		b, err := p.MakeFor(0)
		assert.NoError(t, err)
		assert.Equal(t, expectedPsk, b)

		// Make sure different vpn ips generate different psks
		expectedPsk = []byte{0x92, 0x78, 0x87, 0x1f, 0x9e, 0x66, 0x2d, 0xbd, 0x80, 0xbe, 0x25, 0x65, 0x3a, 0xfe, 0x6, 0xa6, 0x5b, 0xd1, 0x94, 0x83, 0x1b, 0xc1, 0x18, 0x19, 0xaa, 0x41, 0x82, 0xbd, 0x7, 0xb3, 0xff, 0x11}
		b, err = p.MakeFor(1)
		assert.NoError(t, err)
		assert.Equal(t, expectedPsk, b)
	})
}
// BenchmarkPsk_MakeFor measures the cost of generating an outgoing psk in enforced mode.
func BenchmarkPsk_MakeFor(b *testing.B) {
	p, err := NewPsk(PskEnforced, []string{"hi there friends"}, 1)
	if err != nil {
		// Stop here, p is nil when construction fails and the loop below would dereference it
		b.Fatal(err)
	}

	// Keep the one time setup out of the measurement
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := p.MakeFor(99); err != nil {
			b.Fatal(err)
		}
	}
}