mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-08 21:43:57 +01:00
PSK support for v2
This commit is contained in:
parent
50473bd2a8
commit
2c9cc63c1a
@ -27,7 +27,7 @@ type ConnectionState struct {
|
|||||||
writeLock sync.Mutex
|
writeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern, psk []byte) (*ConnectionState, error) {
|
||||||
var dhFunc noise.DHFunc
|
var dhFunc noise.DHFunc
|
||||||
switch crt.Curve() {
|
switch crt.Curve() {
|
||||||
case cert.Curve_CURVE25519:
|
case cert.Curve_CURVE25519:
|
||||||
@ -56,13 +56,12 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
|||||||
b.Update(l, 0)
|
b.Update(l, 0)
|
||||||
|
|
||||||
hs, err := noise.NewHandshakeState(noise.Config{
|
hs, err := noise.NewHandshakeState(noise.Config{
|
||||||
CipherSuite: ncs,
|
CipherSuite: ncs,
|
||||||
Random: rand.Reader,
|
Random: rand.Reader,
|
||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
Initiator: initiator,
|
Initiator: initiator,
|
||||||
StaticKeypair: static,
|
StaticKeypair: static,
|
||||||
//NOTE: These should come from CertState (pki.go) when we finally implement it
|
PresharedKey: psk,
|
||||||
PresharedKey: []byte{},
|
|
||||||
PresharedKeyPlacement: 0,
|
PresharedKeyPlacement: 0,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -1224,3 +1224,135 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPSK(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
myPskMode nebula.PskMode
|
||||||
|
theirPskMode nebula.PskMode
|
||||||
|
}{
|
||||||
|
// All accepting
|
||||||
|
{
|
||||||
|
name: "both accepting",
|
||||||
|
myPskMode: nebula.PskAccepting,
|
||||||
|
theirPskMode: nebula.PskAccepting,
|
||||||
|
},
|
||||||
|
|
||||||
|
// accepting and sending both ways
|
||||||
|
{
|
||||||
|
name: "accepting to sending",
|
||||||
|
myPskMode: nebula.PskAccepting,
|
||||||
|
theirPskMode: nebula.PskSending,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sending to accepting",
|
||||||
|
myPskMode: nebula.PskSending,
|
||||||
|
theirPskMode: nebula.PskAccepting,
|
||||||
|
},
|
||||||
|
|
||||||
|
// All sending
|
||||||
|
{
|
||||||
|
name: "sending to sending",
|
||||||
|
myPskMode: nebula.PskSending,
|
||||||
|
theirPskMode: nebula.PskSending,
|
||||||
|
},
|
||||||
|
|
||||||
|
// enforced and sending both ways
|
||||||
|
{
|
||||||
|
name: "enforced to sending",
|
||||||
|
myPskMode: nebula.PskEnforced,
|
||||||
|
theirPskMode: nebula.PskSending,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sending to enforced",
|
||||||
|
myPskMode: nebula.PskSending,
|
||||||
|
theirPskMode: nebula.PskEnforced,
|
||||||
|
},
|
||||||
|
|
||||||
|
// All enforced
|
||||||
|
{
|
||||||
|
name: "both enforced",
|
||||||
|
myPskMode: nebula.PskEnforced,
|
||||||
|
theirPskMode: nebula.PskEnforced,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Enforced can technically handshake with an accepting node, but it is bad to be in this state
|
||||||
|
{
|
||||||
|
name: "enforced to accepting",
|
||||||
|
myPskMode: nebula.PskEnforced,
|
||||||
|
theirPskMode: nebula.PskAccepting,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var myPskSettings, theirPskSettings m
|
||||||
|
|
||||||
|
switch test.myPskMode {
|
||||||
|
case nebula.PskAccepting:
|
||||||
|
myPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage0", "this is a key"}}}
|
||||||
|
case nebula.PskSending:
|
||||||
|
myPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage1"}}}
|
||||||
|
case nebula.PskEnforced:
|
||||||
|
myPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage2"}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch test.theirPskMode {
|
||||||
|
case nebula.PskAccepting:
|
||||||
|
theirPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage3", "this is a key"}}}
|
||||||
|
case nebula.PskSending:
|
||||||
|
theirPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage4"}}}
|
||||||
|
case nebula.PskEnforced:
|
||||||
|
theirPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage5"}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||||
|
myControl, myVpnIp, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.0.0.1/24", myPskSettings)
|
||||||
|
theirControl, theirVpnIp, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.0.0.2/24", theirPskSettings)
|
||||||
|
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIp[0].Addr(), theirUdpAddr)
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Route until we see our cached packet flow")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIp[0].Addr(), 80, myVpnIp[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
|
h := &header.H{}
|
||||||
|
err := h.Parse(p.Data)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should
|
||||||
|
// not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now
|
||||||
|
if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskSending {
|
||||||
|
if h.Type == 0 && h.MessageCounter == 1 {
|
||||||
|
assert.NotContains(t, string(p.Data), "test me")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.To == theirUdpAddr && h.Type == 1 {
|
||||||
|
return router.RouteAndExit
|
||||||
|
}
|
||||||
|
|
||||||
|
return router.KeepRouting
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Log("My cached packet should be received by them")
|
||||||
|
myCachedPacket := theirControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("Test the tunnel with them")
|
||||||
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
|
||||||
|
assertTunnel(t, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
//TODO: assert hostmaps
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@ -111,10 +111,6 @@ type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
|
|||||||
func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
if err := os.MkdirAll("mermaid", 0755); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := &R{
|
r := &R{
|
||||||
controls: make(map[netip.AddrPort]*nebula.Control),
|
controls: make(map[netip.AddrPort]*nebula.Control),
|
||||||
vpnControls: make(map[netip.Addr]*nebula.Control),
|
vpnControls: make(map[netip.Addr]*nebula.Control),
|
||||||
@ -194,6 +190,9 @@ func (r *R) renderFlow() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(r.fn), 0755); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
|
f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|||||||
@ -19,6 +19,38 @@ pki:
|
|||||||
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
||||||
# default_version: 1
|
# default_version: 1
|
||||||
|
|
||||||
|
# psk can be used to mask the contents of handshakes.
|
||||||
|
psk:
|
||||||
|
# `mode` defines how the pre shared keys can be used in a handshake.
|
||||||
|
# `accepting` (the default) will initiate handshakes using an empty key and will try to use any keys provided when
|
||||||
|
# receiving handshakes, including an empty key.
|
||||||
|
# `sending` will initiate handshakes with the first key provided and will try to use any keys provided when
|
||||||
|
# receiving handshakes, including an empty key.
|
||||||
|
# `enforced` will initiate handshakes with the first psk key provided and will try to use any keys provided when
|
||||||
|
# responding to handshakes. An empty key will not be allowed.
|
||||||
|
#
|
||||||
|
# To change a mesh from not using a psk to enforcing psk:
|
||||||
|
# 1. Leave `mode` as `accepting` and configure `psk.keys` to match on all nodes in the mesh and reload.
|
||||||
|
# 2. Change `mode` to `sending` on all nodes in the mesh and reload.
|
||||||
|
# 3. Change `mode` to `enforced` on all nodes in the mesh and reload.
|
||||||
|
#mode: accepting
|
||||||
|
|
||||||
|
# The keys provided are sent through hkdf to ensure the shared secret used in the noise protocol is the
|
||||||
|
# correct byte length.
|
||||||
|
#
|
||||||
|
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
|
||||||
|
# incoming handshakes. This is to allow for psk rotation.
|
||||||
|
#
|
||||||
|
# To rotate a primary key:
|
||||||
|
# 1. Put the new key in the 2nd slot on every node in the mesh and reload.
|
||||||
|
# 2. Move the key from the 2nd slot to the 1st slot, the old primary key is now in the 2nd slot, reload.
|
||||||
|
# 3. Remove the old primary key once it is no longer in use on every node in the mesh and reload.
|
||||||
|
#keys:
|
||||||
|
# - shared secret string, this one is used in all outbound handshakes # This is the primary key used when sending handshakes
|
||||||
|
# - this is a fallback key, received handshakes can use this
|
||||||
|
# - another fallback, received handshakes can use this one too
|
||||||
|
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire
|
||||||
|
|
||||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||||
# The syntax is:
|
# The syntax is:
|
||||||
@ -313,7 +345,6 @@ logging:
|
|||||||
# after receiving the response for lighthouse queries
|
# after receiving the response for lighthouse queries
|
||||||
#trigger_buffer: 64
|
#trigger_buffer: 64
|
||||||
|
|
||||||
|
|
||||||
# Nebula security group configuration
|
# Nebula security group configuration
|
||||||
firewall:
|
firewall:
|
||||||
# Action to take when a packet is not allowed by the firewall rules.
|
# Action to take when a packet is not allowed by the firewall rules.
|
||||||
|
|||||||
@ -50,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX, cs.psk.primary)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
@ -104,34 +104,53 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
var (
|
||||||
if err != nil {
|
err error
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
ci *ConnectionState
|
||||||
|
msg []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
hs := &NebulaHandshake{}
|
||||||
|
|
||||||
|
for _, psk := range cs.psk.keys {
|
||||||
|
ci, err = NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX, psk)
|
||||||
|
if err != nil {
|
||||||
|
//TODO: should be bother logging this, if we have multiple psks and the error is unrelated it will be verbose.
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
Error("Failed to create connection state")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
|
if err != nil {
|
||||||
|
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
|
||||||
|
// comes out clean as well
|
||||||
|
err = hs.Unmarshal(msg)
|
||||||
|
if err == nil {
|
||||||
|
// There was no error, we can continue with this handshake
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// The unmarshal failed, try the next psk if we have one
|
||||||
|
}
|
||||||
|
|
||||||
|
// We finished with an error, log it and get out
|
||||||
|
if err != nil || hs.Details == nil {
|
||||||
|
// We aren't logging the error here because we can't be sure of the failure when using psk
|
||||||
|
f.l.WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Error("Failed to create connection state")
|
Error("Was unable to decrypt the handshake")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark packet 1 as seen so it doesn't show up as missed
|
// Mark packet 1 as seen so it doesn't show up as missed
|
||||||
ci.window.Update(f.l, 1)
|
ci.window.Update(f.l, 1)
|
||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Error("Failed to call noise.ReadMessage")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hs := &NebulaHandshake{}
|
|
||||||
err = hs.Unmarshal(msg)
|
|
||||||
if err != nil || hs.Details == nil {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Error("Failed unmarshal handshake message")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||||
@ -23,11 +24,15 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
|
|
||||||
|
psk, err := NewPsk(PskAccepting, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
defaultVersion: cert.Version1,
|
defaultVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
|
psk: psk,
|
||||||
}
|
}
|
||||||
|
|
||||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||||
|
|||||||
12
pki.go
12
pki.go
@ -38,6 +38,8 @@ type CertState struct {
|
|||||||
pkcs11Backed bool
|
pkcs11Backed bool
|
||||||
cipher string
|
cipher string
|
||||||
|
|
||||||
|
psk *Psk
|
||||||
|
|
||||||
myVpnNetworks []netip.Prefix
|
myVpnNetworks []netip.Prefix
|
||||||
myVpnNetworksTable *bart.Table[struct{}]
|
myVpnNetworksTable *bart.Table[struct{}]
|
||||||
myVpnAddrs []netip.Addr
|
myVpnAddrs []netip.Addr
|
||||||
@ -171,6 +173,16 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
psk, err := NewPskFromConfig(c)
|
||||||
|
if err != nil {
|
||||||
|
return util.NewContextualError("Failed to load psk from config", nil, err)
|
||||||
|
}
|
||||||
|
if len(psk.keys) > 0 {
|
||||||
|
p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)).
|
||||||
|
Info("pre shared keys are in use")
|
||||||
|
}
|
||||||
|
newState.psk = psk
|
||||||
|
|
||||||
p.cs.Store(newState)
|
p.cs.Store(newState)
|
||||||
|
|
||||||
//TODO: CERT-V2 newState needs a stringer that does json
|
//TODO: CERT-V2 newState needs a stringer that does json
|
||||||
|
|||||||
150
psk.go
Normal file
150
psk.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrNotAPskMode = errors.New("not a psk mode")
|
||||||
|
var ErrKeyTooShort = errors.New("key is too short")
|
||||||
|
var ErrNotEnoughPskKeys = errors.New("at least 1 key is required")
|
||||||
|
|
||||||
|
// MinPskLength is the minimum bytes that we accept for a user defined psk, the choice is arbitrary
|
||||||
|
const MinPskLength = 8
|
||||||
|
|
||||||
|
type PskMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
PskAccepting PskMode = 0
|
||||||
|
PskSending PskMode = 1
|
||||||
|
PskEnforced PskMode = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewPskMode(m string) (PskMode, error) {
|
||||||
|
switch m {
|
||||||
|
case "accepting":
|
||||||
|
return PskAccepting, nil
|
||||||
|
case "sending":
|
||||||
|
return PskSending, nil
|
||||||
|
case "enforced":
|
||||||
|
return PskEnforced, nil
|
||||||
|
}
|
||||||
|
return PskAccepting, ErrNotAPskMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PskMode) String() string {
|
||||||
|
switch p {
|
||||||
|
case PskAccepting:
|
||||||
|
return "accepting"
|
||||||
|
case PskSending:
|
||||||
|
return "sending"
|
||||||
|
case PskEnforced:
|
||||||
|
return "enforced"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PskMode) IsValid() bool {
|
||||||
|
switch p {
|
||||||
|
case PskAccepting, PskSending, PskEnforced:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Psk struct {
|
||||||
|
// pskMode sets how psk works, ignored, allowed for incoming, or enforced for all
|
||||||
|
mode PskMode
|
||||||
|
|
||||||
|
// primary is the key to use when sending, it may be nil
|
||||||
|
primary []byte
|
||||||
|
|
||||||
|
// keys holds all pre-computed psk hkdfs
|
||||||
|
// Handshakes iterate this directly
|
||||||
|
keys [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPskFromConfig is a helper for initial boot and config reloading.
|
||||||
|
func NewPskFromConfig(c *config.C) (*Psk, error) {
|
||||||
|
sMode := c.GetString("psk.mode", "accepting")
|
||||||
|
mode, err := NewPskMode(sMode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.NewContextualError("Could not parse psk.mode", m{"mode": mode}, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewPsk(
|
||||||
|
mode,
|
||||||
|
c.GetStringSlice("psk.keys", nil),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPsk creates a new Psk object and handles the caching of all accepted keys
|
||||||
|
func NewPsk(mode PskMode, keys []string) (*Psk, error) {
|
||||||
|
if !mode.IsValid() {
|
||||||
|
return nil, ErrNotAPskMode
|
||||||
|
}
|
||||||
|
|
||||||
|
psk := &Psk{
|
||||||
|
mode: mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := psk.cachePsks(keys)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return psk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cachePsks generates all psks we accept and caches them to speed up handshaking
|
||||||
|
func (p *Psk) cachePsks(keys []string) error {
|
||||||
|
if p.mode != PskAccepting && len(keys) < 1 {
|
||||||
|
return ErrNotEnoughPskKeys
|
||||||
|
}
|
||||||
|
|
||||||
|
p.keys = [][]byte{}
|
||||||
|
|
||||||
|
for i, rk := range keys {
|
||||||
|
k, err := sha256KdfFromString(rk)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate key for position %v: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.keys = append(p.keys, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.mode != PskAccepting {
|
||||||
|
// We are either sending or enforcing, the primary key must the first slot
|
||||||
|
p.primary = p.keys[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.mode != PskEnforced {
|
||||||
|
// If we are not enforcing psk use then a nil psk is allowed
|
||||||
|
p.keys = append(p.keys, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sha256KdfFromString generates a useful key to use from a provided secret
|
||||||
|
func sha256KdfFromString(secret string) ([]byte, error) {
|
||||||
|
if len(secret) < MinPskLength {
|
||||||
|
return nil, ErrKeyTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
hmacKey := make([]byte, sha256.Size)
|
||||||
|
_, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, nil), hmacKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return hmacKey, nil
|
||||||
|
}
|
||||||
71
psk_test.go
Normal file
71
psk_test.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewPsk(t *testing.T) {
|
||||||
|
t.Run("mode accepting", func(t *testing.T) {
|
||||||
|
p, err := NewPsk(PskAccepting, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, PskAccepting, p.mode)
|
||||||
|
assert.Nil(t, p.keys[0])
|
||||||
|
assert.Nil(t, p.primary)
|
||||||
|
|
||||||
|
p, err = NewPsk(PskAccepting, []string{"1234567"})
|
||||||
|
assert.Error(t, ErrKeyTooShort)
|
||||||
|
|
||||||
|
p, err = NewPsk(PskAccepting, []string{"hi there friends"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, PskAccepting, p.mode)
|
||||||
|
assert.Nil(t, p.primary)
|
||||||
|
assert.Len(t, p.keys, 2)
|
||||||
|
assert.Nil(t, p.keys[1])
|
||||||
|
|
||||||
|
expectedCache := []byte{
|
||||||
|
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
|
||||||
|
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedCache, p.keys[0])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mode sending", func(t *testing.T) {
|
||||||
|
p, err := NewPsk(PskSending, nil)
|
||||||
|
assert.Error(t, ErrNotEnoughPskKeys, err)
|
||||||
|
|
||||||
|
p, err = NewPsk(PskSending, []string{"1234567"})
|
||||||
|
assert.Error(t, ErrKeyTooShort)
|
||||||
|
|
||||||
|
p, err = NewPsk(PskSending, []string{"hi there friends"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, PskSending, p.mode)
|
||||||
|
assert.Len(t, p.keys, 2)
|
||||||
|
assert.Nil(t, p.keys[1])
|
||||||
|
|
||||||
|
expectedCache := []byte{
|
||||||
|
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
|
||||||
|
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedCache, p.keys[0])
|
||||||
|
assert.Equal(t, p.keys[0], p.primary)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mode enforced", func(t *testing.T) {
|
||||||
|
p, err := NewPsk(PskEnforced, nil)
|
||||||
|
assert.Error(t, ErrNotEnoughPskKeys, err)
|
||||||
|
|
||||||
|
p, err = NewPsk(PskEnforced, []string{"hi there friends"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, PskEnforced, p.mode)
|
||||||
|
assert.Len(t, p.keys, 1)
|
||||||
|
|
||||||
|
expectedCache := []byte{
|
||||||
|
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
|
||||||
|
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedCache, p.keys[0])
|
||||||
|
assert.Equal(t, p.keys[0], p.primary)
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user