mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-11 21:03:57 +01:00
PSK Support
This commit is contained in:
parent
a22c134bf5
commit
cf3b7ec2fa
@ -27,7 +27,7 @@ type ConnectionState struct {
|
||||
ready bool
|
||||
}
|
||||
|
||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, p []byte) (*ConnectionState, error) {
|
||||
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
|
||||
if f.cipher == "chachapoly" {
|
||||
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
@ -43,14 +43,15 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
||||
hs, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: cs,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Pattern: noise.HandshakeIX,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: static,
|
||||
PresharedKey: psk,
|
||||
PresharedKeyPlacement: pskStage,
|
||||
PresharedKey: p,
|
||||
PresharedKeyPlacement: 0,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The queue and ready params prevent a counter race that would happen when
|
||||
@ -63,7 +64,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
||||
certState: curCertState,
|
||||
}
|
||||
|
||||
return ci
|
||||
return ci, nil
|
||||
}
|
||||
|
||||
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||
|
||||
@ -215,17 +215,42 @@ logging:
|
||||
# e.g.: `lighthouse.rx.HostQuery`
|
||||
#lighthouse_metrics: false
|
||||
|
||||
# Handshake Manager Settings
|
||||
#handshakes:
|
||||
# Handshake Manger Settings
|
||||
handshakes:
|
||||
# Handshakes are sent to all known addresses at each interval with a linear backoff,
|
||||
# Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
|
||||
# A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
|
||||
#try_interval: 100ms
|
||||
#retries: 20
|
||||
|
||||
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
|
||||
# after receiving the response for lighthouse queries
|
||||
#trigger_buffer: 64
|
||||
|
||||
# pki can be used to mask the contents of handshakes and makes handshaking with unintended recipients more difficult
|
||||
psk:
|
||||
# mode defines the how pre shared keys can be used in a handshake
|
||||
# `none` (the default) does not send or receive using a psk. Ideally `enforced` is used.
|
||||
# `transitional` can receive handshakes using a psk that we know about, but we will not send any handshakes using a psk.
|
||||
# This is helpful for transitioning to `enforced` and should be changed to `enforced` as soon as possible.
|
||||
# Move every node in your mesh to `transitional` then you can move every node in your mesh to `enforced` without having to stop the world
|
||||
# This assumes `keys` is the same on every node in your mesh
|
||||
# `enforced` enforces the use of a psk for all tunnels. Any node not also using `enforced` or `transitional` will not be able to handshake with us
|
||||
#mode: none
|
||||
|
||||
# In `transitional` and `enforced` modes, the keys provided here are sent through hkdf with the intended recipients
|
||||
# ip used in the info section. This helps guard against handshaking with the wrong host if your static_host_map or
|
||||
# lighthouse(s) has incorrect information.
|
||||
#
|
||||
# Setting keys if mode is `none` has no effect.
|
||||
#
|
||||
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
|
||||
# incoming handshakes. This is to allow for psk rotation.
|
||||
#keys:
|
||||
# - shared secret string, this one is used in all outbound handshakes
|
||||
# - this is a fallback key, received handshakes can use this
|
||||
# - another fallback, received handshakes can use this one too
|
||||
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire
|
||||
|
||||
# Nebula security group configuration
|
||||
firewall:
|
||||
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
@ -71,28 +70,51 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
||||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
|
||||
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||
return
|
||||
}
|
||||
var (
|
||||
err error
|
||||
ci *ConnectionState
|
||||
msg []byte
|
||||
)
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = proto.Unmarshal(msg, hs)
|
||||
/*
|
||||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
||||
*/
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
|
||||
// Handle multiple possible psk options, ensure the protobuf comes out clean too
|
||||
for _, p := range f.psk.Cache {
|
||||
//TODO: benchmark generation time of makePsk
|
||||
ci, err = f.newConnectionState(f.l, false, p)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).Error("Failed to get a new connection state")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
|
||||
continue
|
||||
}
|
||||
|
||||
// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
|
||||
// comes out clean as well
|
||||
err = proto.Unmarshal(msg, hs)
|
||||
if err == nil {
|
||||
// There was no error, we can continue with this handshake
|
||||
break
|
||||
}
|
||||
|
||||
// The unmarshal failed, try the next psk if we have one
|
||||
}
|
||||
|
||||
// We finished with an error, log it and get out
|
||||
if err != nil {
|
||||
// We aren't logging the error here because we can't be sure of the failure when using psk
|
||||
f.l.WithField("udpAddr", addr).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Was unable to decrypt the handshake")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
|
||||
24
inside.go
24
inside.go
@ -3,7 +3,6 @@ package nebula
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
@ -79,7 +78,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
||||
}
|
||||
hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
|
||||
|
||||
//if err != nil || hostinfo.ConnectionState == nil {
|
||||
if err != nil {
|
||||
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
|
||||
if err != nil {
|
||||
@ -102,21 +100,27 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
||||
return hostinfo
|
||||
}
|
||||
|
||||
// Create a connection state if we don't have one yet
|
||||
if ci == nil {
|
||||
// if we don't have a connection state, then send a handshake initiation
|
||||
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
|
||||
// Generate a PSK based on our config, this may be nil
|
||||
p, err := f.psk.MakeFor(vpnIp)
|
||||
if err != nil {
|
||||
//TODO: This isn't fatal specifically but it's pretty bad
|
||||
f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a PSK KDF")
|
||||
return hostinfo
|
||||
}
|
||||
|
||||
ci, err = f.newConnectionState(f.l, true, p)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a connection state")
|
||||
return hostinfo
|
||||
}
|
||||
hostinfo.ConnectionState = ci
|
||||
} else if ci.eKey == nil {
|
||||
// if we don't have any state at all, create it
|
||||
}
|
||||
|
||||
// If we have already created the handshake packet, we don't want to call the function at all.
|
||||
if !hostinfo.HandshakeReady {
|
||||
ixHandshakeStage0(f, vpnIp, hostinfo)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//xx_handshakeStage0(f, ip, hostinfo)
|
||||
|
||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||
// We can trigger the handshake right now
|
||||
|
||||
25
interface.go
25
interface.go
@ -7,7 +7,9 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -48,6 +50,7 @@ type InterfaceConfig struct {
|
||||
version string
|
||||
caPool *cert.NebulaCAPool
|
||||
disconnectInvalid bool
|
||||
psk *Psk
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
@ -78,9 +81,9 @@ type Interface struct {
|
||||
version string
|
||||
|
||||
conntrackCacheTimeout time.Duration
|
||||
|
||||
writers []*udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
psk *Psk
|
||||
writers []*udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
@ -104,6 +107,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
}
|
||||
|
||||
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
|
||||
|
||||
ifce := &Interface{
|
||||
hostMap: c.HostMap,
|
||||
outside: c.Outside,
|
||||
@ -124,6 +128,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
caPool: c.caPool,
|
||||
disconnectInvalid: c.disconnectInvalid,
|
||||
psk: c.psk,
|
||||
myVpnIp: myVpnIp,
|
||||
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
@ -234,6 +239,7 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
for _, udpConn := range f.writers {
|
||||
c.RegisterReloadCallback(udpConn.ReloadConfig)
|
||||
}
|
||||
c.RegisterReloadCallback(f.reloadPSKs)
|
||||
}
|
||||
|
||||
func (f *Interface) reloadCA(c *config.C) {
|
||||
@ -308,6 +314,19 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
||||
Info("New firewall has been installed")
|
||||
}
|
||||
|
||||
func (f *Interface) reloadPSKs(c *config.C) {
|
||||
psk, err := NewPskFromConfig(c, f.myVpnIp)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while reloading PSKs")
|
||||
return
|
||||
}
|
||||
|
||||
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&f.psk)), unsafe.Pointer(psk))
|
||||
|
||||
f.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.Cache)).
|
||||
Info("New psks are in use")
|
||||
}
|
||||
|
||||
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
ticker := time.NewTicker(i)
|
||||
defer ticker.Stop()
|
||||
|
||||
10
main.go
10
main.go
@ -95,6 +95,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}
|
||||
|
||||
psk, err := NewPskFromConfig(c, iputil.Ip2VpnIp(tunCidr.IP))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to create psk", nil, err)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// All non system modifying configuration consumption should live above this line
|
||||
// tun config, listeners, anything modifying the computer should be below
|
||||
@ -356,10 +361,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||
|
||||
//TODO: These will be reused for psk
|
||||
//handshakeMACKey := config.GetString("handshake_mac.key", "")
|
||||
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
|
||||
|
||||
serveDns := false
|
||||
if c.GetBool("lighthouse.serve_dns", false) {
|
||||
if c.GetBool("lighthouse.am_lighthouse", false) {
|
||||
@ -390,6 +391,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
version: buildVersion,
|
||||
caPool: caPool,
|
||||
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
|
||||
psk: psk,
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
|
||||
1
noise.go
1
noise.go
@ -22,7 +22,6 @@ type NebulaCipherState struct {
|
||||
|
||||
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
|
||||
return &NebulaCipherState{c: s.Cipher()}
|
||||
|
||||
}
|
||||
|
||||
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
|
||||
178
psk.go
Normal file
178
psk.go
Normal file
@ -0,0 +1,178 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
var ErrNotAPskMode = errors.New("not a psk mode")
|
||||
var ErrKeyTooShort = errors.New("key is too short")
|
||||
var ErrNotEnoughPskKeys = errors.New("at least 1 key is required")
|
||||
|
||||
// The minimum length that we accept for a user defined psk, the choice is arbitrary
|
||||
const MinPskLength = 8
|
||||
|
||||
type PskMode int
|
||||
|
||||
func (p PskMode) String() string {
|
||||
switch p {
|
||||
case PskNone:
|
||||
return "none"
|
||||
case PskTransitional:
|
||||
return "transitional"
|
||||
case PskEnforced:
|
||||
return "enforced"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func NewPskMode(m string) (PskMode, error) {
|
||||
switch m {
|
||||
case "none":
|
||||
return PskNone, nil
|
||||
case "transitional":
|
||||
return PskTransitional, nil
|
||||
case "enforced":
|
||||
return PskEnforced, nil
|
||||
}
|
||||
return PskNone, ErrNotAPskMode
|
||||
}
|
||||
|
||||
const (
|
||||
PskNone PskMode = 0
|
||||
PskTransitional PskMode = 1
|
||||
PskEnforced PskMode = 2
|
||||
)
|
||||
|
||||
type Psk struct {
|
||||
// pskMode sets how psk works, ignored, allowed for incoming, or enforced for all
|
||||
mode PskMode
|
||||
|
||||
// Cache holds all pre-computed psk hkdfs
|
||||
// Handshakes iterate this directly
|
||||
Cache [][]byte
|
||||
|
||||
// The key has already been extracted and is ready to be expanded for use
|
||||
// MakeFor does the final expand step mixing in the intended recipients vpn ip
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewPskFromConfig is a helper for initial boot and config reloading.
|
||||
func NewPskFromConfig(c *config.C, myVpnIp iputil.VpnIp) (*Psk, error) {
|
||||
sMode := c.GetString("handshakes.psk.mode", "none")
|
||||
mode, err := NewPskMode(sMode)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Could not parse handshakes.psk.mode", m{"mode": mode}, err)
|
||||
}
|
||||
|
||||
return NewPsk(
|
||||
mode,
|
||||
c.GetStringSlice("handshakes.psk.keys", nil),
|
||||
myVpnIp,
|
||||
)
|
||||
}
|
||||
|
||||
// NewPsk creates a new Psk object and handles the caching of all accepted keys and preparation of the primary key
|
||||
func NewPsk(mode PskMode, keys []string, myVpnIp iputil.VpnIp) (*Psk, error) {
|
||||
psk := &Psk{
|
||||
mode: mode,
|
||||
}
|
||||
|
||||
err := psk.preparePrimaryKey(keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = psk.cachePsks(myVpnIp, keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return psk, nil
|
||||
}
|
||||
|
||||
// MakeFor if we are in enforced mode, the final hkdf expand stage is done on the pre extracted primary key,
|
||||
// mixing in the intended recipients vpn ip and the result is returned.
|
||||
// If we are transitional or not using psks, an empty byte slice is returned
|
||||
func (p *Psk) MakeFor(vpnIp iputil.VpnIp) ([]byte, error) {
|
||||
if p.mode != PskEnforced {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
hmacKey := make([]byte, sha256.Size)
|
||||
_, err := io.ReadFull(hkdf.Expand(sha256.New, p.key, vpnIp.ToIP()), hmacKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hmacKey, nil
|
||||
}
|
||||
|
||||
// cachePsks generates all psks we accept and caches them to speed up handshaking
|
||||
func (p *Psk) cachePsks(myVpnIp iputil.VpnIp, keys []string) error {
|
||||
// If PskNone is set then we are using the nil byte array for a psk, we can return
|
||||
if p.mode == PskNone {
|
||||
p.Cache = [][]byte{nil}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(keys) < 1 {
|
||||
return ErrNotEnoughPskKeys
|
||||
}
|
||||
|
||||
p.Cache = [][]byte{}
|
||||
|
||||
if p.mode == PskTransitional {
|
||||
// We are transitional, we accept empty psks
|
||||
p.Cache = append(p.Cache, nil)
|
||||
}
|
||||
|
||||
// We are either PskAuto or PskTransitional, build all possibilities
|
||||
for i, rk := range keys {
|
||||
k, err := sha256KdfFromString(rk, myVpnIp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate key for position %v: %w", i, err)
|
||||
}
|
||||
|
||||
p.Cache = append(p.Cache, k)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// preparePrimaryKey if we are in enforced mode, will do an hkdf extract on the first key to benefit
|
||||
// outgoing handshake performance, MakeFor does the final expand step
|
||||
func (p *Psk) preparePrimaryKey(keys []string) error {
|
||||
if p.mode != PskEnforced {
|
||||
// If we aren't enforcing then there is nothing to prepare
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(keys) < 1 {
|
||||
return ErrNotEnoughPskKeys
|
||||
}
|
||||
|
||||
p.key = hkdf.Extract(sha256.New, []byte(keys[0]), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sha256KdfFromString generates a full hkdf
|
||||
func sha256KdfFromString(secret string, vpnIp iputil.VpnIp) ([]byte, error) {
|
||||
if len(secret) < MinPskLength {
|
||||
return nil, ErrKeyTooShort
|
||||
}
|
||||
|
||||
hmacKey := make([]byte, sha256.Size)
|
||||
_, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, vpnIp.ToIP()), hmacKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hmacKey, nil
|
||||
}
|
||||
78
psk_test.go
Normal file
78
psk_test.go
Normal file
@ -0,0 +1,78 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewPsk(t *testing.T) {
|
||||
t.Run("mode none", func(t *testing.T) {
|
||||
p, err := NewPsk(PskNone, nil, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PskNone, p.mode)
|
||||
assert.Empty(t, p.key)
|
||||
|
||||
assert.Len(t, p.Cache, 1)
|
||||
assert.Nil(t, p.Cache[0])
|
||||
|
||||
b, err := p.MakeFor(0)
|
||||
assert.Equal(t, []byte{}, b)
|
||||
})
|
||||
|
||||
t.Run("mode transitional", func(t *testing.T) {
|
||||
p, err := NewPsk(PskTransitional, nil, 1)
|
||||
assert.Error(t, ErrNotEnoughPskKeys, err)
|
||||
|
||||
p, err = NewPsk(PskTransitional, []string{"1234567"}, 1)
|
||||
assert.Error(t, ErrKeyTooShort)
|
||||
|
||||
p, err = NewPsk(PskTransitional, []string{"hi there friends"}, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PskTransitional, p.mode)
|
||||
assert.Empty(t, p.key)
|
||||
|
||||
assert.Len(t, p.Cache, 2)
|
||||
assert.Nil(t, p.Cache[0])
|
||||
|
||||
expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17}
|
||||
assert.Equal(t, expectedCache, p.Cache[1])
|
||||
|
||||
b, err := p.MakeFor(0)
|
||||
assert.Equal(t, []byte{}, b)
|
||||
})
|
||||
|
||||
t.Run("mode enforced", func(t *testing.T) {
|
||||
p, err := NewPsk(PskEnforced, nil, 1)
|
||||
assert.Error(t, ErrNotEnoughPskKeys, err)
|
||||
|
||||
p, err = NewPsk(PskEnforced, []string{"hi there friends"}, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PskEnforced, p.mode)
|
||||
|
||||
expectedKey := []byte{156, 103, 171, 88, 121, 92, 138, 240, 170, 240, 76, 108, 154, 66, 107, 14, 226, 148, 177, 0, 40, 28, 220, 136, 68, 53, 63, 183, 213, 9, 192, 218}
|
||||
assert.Equal(t, expectedKey, p.key)
|
||||
|
||||
assert.Len(t, p.Cache, 1)
|
||||
expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17}
|
||||
assert.Equal(t, expectedCache, p.Cache[0])
|
||||
|
||||
expectedPsk := []byte{0xd9, 0x16, 0xa3, 0x66, 0x6a, 0x20, 0x26, 0xcf, 0x5d, 0x93, 0xad, 0xa3, 0x88, 0x2d, 0x57, 0xac, 0x9b, 0xc3, 0x5a, 0xb7, 0x8f, 0x6, 0x71, 0xc4, 0x3e, 0x5, 0x9e, 0xbc, 0x4e, 0xc8, 0x24, 0x17}
|
||||
b, err := p.MakeFor(0)
|
||||
assert.Equal(t, expectedPsk, b)
|
||||
|
||||
// Make sure different vpn ips generate different psks
|
||||
expectedPsk = []byte{0x92, 0x78, 0x87, 0x1f, 0x9e, 0x66, 0x2d, 0xbd, 0x80, 0xbe, 0x25, 0x65, 0x3a, 0xfe, 0x6, 0xa6, 0x5b, 0xd1, 0x94, 0x83, 0x1b, 0xc1, 0x18, 0x19, 0xaa, 0x41, 0x82, 0xbd, 0x7, 0xb3, 0xff, 0x11}
|
||||
b, err = p.MakeFor(1)
|
||||
assert.Equal(t, expectedPsk, b)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkPsk_MakeFor(b *testing.B) {
|
||||
p, err := NewPsk(PskEnforced, []string{"hi there friends"}, 1)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
p.MakeFor(99)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user