Mirror of https://github.com/slackhq/nebula.git
There are some subtle race conditions with the previous handshake_ix implementation, mostly around collisions with localIndexId. This change refactors it so that we have a "commit" phase during the handshake where we grab the lock for the hostmap and ensure that we have a unique local index before storing it. We also now avoid using the pending hostmap at all for receiving stage1 packets, since we have everything we need to just store the completed handshake.

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
Co-authored-by: Ryan Huber <rhuber@gmail.com>
Co-authored-by: forfuncsake <drussell@slack-corp.com>
360 lines · 12 KiB · Go
package nebula

import (
	"bytes"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"net"
	"time"

	"github.com/sirupsen/logrus"
)

const (
	// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
	// With a 100ms interval and 20 retries, this is 23.5 seconds
	DefaultHandshakeTryInterval = time.Millisecond * 100
	DefaultHandshakeRetries     = 20
	// DefaultHandshakeWaitRotation is the number of handshake attempts to make before starting to try other ip addresses
	DefaultHandshakeWaitRotation  = 5
	DefaultHandshakeTriggerBuffer = 64
)

var (
	defaultHandshakeConfig = HandshakeConfig{
		tryInterval:   DefaultHandshakeTryInterval,
		retries:       DefaultHandshakeRetries,
		waitRotation:  DefaultHandshakeWaitRotation,
		triggerBuffer: DefaultHandshakeTriggerBuffer,
	}
)
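
// HandshakeConfig holds the tunables that drive handshake retries: the retry
// interval, the maximum number of retries, the attempt count after which we
// start rotating through alternate remote addresses, and the buffer size of
// the trigger channel.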
type HandshakeConfig struct {
	tryInterval   time.Duration
	retries       int
	waitRotation  int
	triggerBuffer int

	messageMetrics *MessageMetrics
}
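
// HandshakeManager tracks in-progress handshakes in a pending HostMap and
// promotes them into the main HostMap once they complete.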
type HandshakeManager struct {
	pendingHostMap *HostMap
	mainHostMap    *HostMap
	lightHouse     *LightHouse
	outside        *udpConn
	config         HandshakeConfig

	// can be used to trigger an outbound handshake for the given vpnIP
	trigger chan uint32

	OutboundHandshakeTimer *SystemTimerWheel
	InboundHandshakeTimer  *SystemTimerWheel

	messageMetrics *MessageMetrics
}
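
// NewHandshakeManager returns a HandshakeManager wired to the given hostmap,
// lighthouse and udp socket, with an empty pending hostmap and timer wheels
// sized from the supplied HandshakeConfig.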
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
	return &HandshakeManager{
		pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
		mainHostMap:    mainHostMap,
		lightHouse:     lightHouse,
		outside:        outside,

		config: config,

		trigger: make(chan uint32, config.triggerBuffer),

		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
		InboundHandshakeTimer:  NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),

		messageMetrics: config.messageMetrics,
	}
}
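
// Run is the handshake manager main loop. It services lighthouse-triggered
// handshakes from the trigger channel and drives both timer wheels once per
// tryInterval tick.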
func (c *HandshakeManager) Run(f EncWriter) {
	clockSource := time.Tick(c.config.tryInterval)
	for {
		select {
		case vpnIP := <-c.trigger:
			l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
			c.handleOutbound(vpnIP, f, true)
		case now := <-clockSource:
			c.NextOutboundHandshakeTimerTick(now, f)
			c.NextInboundHandshakeTimerTick(now)
		}
	}
}
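
// NextOutboundHandshakeTimerTick advances the outbound timer wheel to now and
// retries every handshake whose timer has expired.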
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
	c.OutboundHandshakeTimer.advance(now)
	for {
		ep := c.OutboundHandshakeTimer.Purge()
		if ep == nil {
			break
		}
		vpnIP := ep.(uint32)
		c.handleOutbound(vpnIP, f, false)
	}
}
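
// handleOutbound performs one attempt of an in-progress outbound handshake:
// it refreshes remotes from the lighthouse if needed, resends the cached
// handshake packet, and reschedules itself on the timer wheel with a growing
// delay. lighthouseTriggered is true when the attempt was prompted by a
// lighthouse reply rather than the timer wheel.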
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
	hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
	if err != nil {
		return
	}
	hostinfo.Lock()
	defer hostinfo.Unlock()

	// If we haven't finished the handshake and we haven't hit max retries, query
	// the lighthouse and then send the handshake packet again.
	if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
		if hostinfo.remote == nil {
			// We continue to query the lighthouse because hosts may
			// come online during handshake retries. If the query
			// succeeds (no error), add the lighthouse info to hostinfo
			ips := c.lightHouse.QueryCache(vpnIP)
			// If we have no responses yet, or only one IP (the host hadn't
			// finished reporting its own IPs yet), then send another query to
			// the LH.
			if len(ips) <= 1 {
				ips, err = c.lightHouse.Query(vpnIP, f)
			}
			if err == nil {
				for _, ip := range ips {
					hostinfo.AddRemote(ip)
				}
				hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
			}
		} else if lighthouseTriggered {
			// We were triggered by a lighthouse HostQueryReply packet, but
			// we have already picked a remote for this host (this can happen
			// if we are configured with multiple lighthouses). So we can skip
			// this trigger and let the timerwheel handle the rest of the
			// process.
			return
		}

		hostinfo.HandshakeCounter++

		// We want to use the "best" calculated ip for the first 5 attempts; after that we just blindly rotate through
		// all the others until we can stand up a connection.
		if hostinfo.HandshakeCounter > c.config.waitRotation {
			hostinfo.rotateRemote()
		}

		// Ensure the handshake is ready to avoid a race between the timer tick and stage 0 handshake generation
		if hostinfo.HandshakeReady && hostinfo.remote != nil {
			c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
			err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
			if err != nil {
				hostinfo.logger().WithField("udpAddr", hostinfo.remote).
					WithField("initiatorIndex", hostinfo.localIndexId).
					WithField("remoteIndex", hostinfo.remoteIndexId).
					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
					WithError(err).Error("Failed to send handshake message")
			} else {
				//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
				// keep the real packet struct around for logging purposes
				hostinfo.logger().WithField("udpAddr", hostinfo.remote).
					WithField("initiatorIndex", hostinfo.localIndexId).
					WithField("remoteIndex", hostinfo.remoteIndexId).
					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
					Info("Handshake message sent")
			}
		}

		// Re-add to the timer wheel so we continue trying, waiting HandshakeTryInterval * counter longer before the next try
		if !lighthouseTriggered {
			//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
			c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
		}
	} else {
		c.pendingHostMap.DeleteHostInfo(hostinfo)
	}
}
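
// NextInboundHandshakeTimerTick advances the inbound timer wheel to now and
// removes expired pending inbound handshakes from the pending hostmap.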
func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
	c.InboundHandshakeTimer.advance(now)
	for {
		ep := c.InboundHandshakeTimer.Purge()
		if ep == nil {
			break
		}
		index := ep.(uint32)

		c.pendingHostMap.DeleteIndex(index)
	}
}
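
// AddVpnIP adds (or returns) a pending HostInfo for vpnIP and schedules the
// first outbound handshake attempt on the timer wheel.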
func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
	hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
	// We lock here and use an array to insert items to prevent locking the
	// main receive thread for very long by waiting to add items to the pending map
	c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)

	return hostinfo
}

var (
	ErrExistingHostInfo    = errors.New("existing hostinfo")
	ErrAlreadySeen         = errors.New("already seen")
	ErrLocalIndexCollision = errors.New("local index collision")
)

// CheckAndComplete checks for any conflicts in the main and pending hostmap
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
//
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
// exact same handshake packet
//
// ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIP and overwrite was false.
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
	c.pendingHostMap.RLock()
	defer c.pendingHostMap.RUnlock()
	c.mainHostMap.Lock()
	defer c.mainHostMap.Unlock()

	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
	if found && existingHostInfo != nil {
		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
			return existingHostInfo, ErrAlreadySeen
		}
		if !overwrite {
			return existingHostInfo, ErrExistingHostInfo
		}
	}

	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
	if found {
		// We have a collision, but for a different hostinfo
		return existingIndex, ErrLocalIndexCollision
	}
	existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
	if found && existingIndex != hostinfo {
		// We have a collision, but for a different hostinfo
		return existingIndex, ErrLocalIndexCollision
	}

	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
	if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
		// We have a collision, but this can happen since we can't control
		// the remote ID. Just log about the situation as a note.
		hostinfo.logger().
			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
			Info("New host shadows existing host remoteIndex")
	}

	if existingHostInfo != nil {
		// We are going to overwrite this entry, so remove the old references
		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
	}

	c.mainHostMap.addHostInfo(hostinfo, f)
	return existingHostInfo, nil
}

// Complete is a simpler version of CheckAndComplete when we already know we
// won't have a localIndexId collision because we already have an entry in the
// pendingHostMap
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
	c.mainHostMap.Lock()
	defer c.mainHostMap.Unlock()

	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
	if found && existingHostInfo != nil {
		// We are going to overwrite this entry, so remove the old references
		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
	}

	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
	if found && existingRemoteIndex != nil {
		// We have a collision, but this can happen since we can't control
		// the remote ID. Just log about the situation as a note.
		hostinfo.logger().
			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
			Info("New host shadows existing host remoteIndex")
	}

	c.mainHostMap.addHostInfo(hostinfo, f)
}

// AddIndexHostInfo generates a unique localIndexId for this HostInfo
// and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId
func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
	c.pendingHostMap.Lock()
	defer c.pendingHostMap.Unlock()
	c.mainHostMap.RLock()
	defer c.mainHostMap.RUnlock()

	for i := 0; i < 32; i++ {
		index, err := generateIndex()
		if err != nil {
			return err
		}

		_, inPending := c.pendingHostMap.Indexes[index]
		_, inMain := c.mainHostMap.Indexes[index]

		if !inMain && !inPending {
			h.localIndexId = index
			c.pendingHostMap.Indexes[index] = h
			return nil
		}
	}

	return errors.New("failed to generate unique localIndexId")
}
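
// addRemoteIndexHostInfo records the remote index for a pending handshake in
// the pending hostmap.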
func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
	c.pendingHostMap.addRemoteIndexHostInfo(index, h)
}
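
// DeleteHostInfo removes a hostinfo from the pending hostmap.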
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
	//l.Debugln("Deleting pending hostinfo :", hostinfo)
	c.pendingHostMap.DeleteHostInfo(hostinfo)
}
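
// QueryIndex looks up a pending handshake by its local index.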
func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
	return c.pendingHostMap.QueryIndex(index)
}
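
// EmitStats emits metrics for both the pending and main hostmaps.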
func (c *HandshakeManager) EmitStats() {
	c.pendingHostMap.EmitStats("pending")
	c.mainHostMap.EmitStats("main")
}

// Utility functions below
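
// generateIndex returns a random, non-zero uint32 suitable for use as a
// local index. Zero is reserved to mean the index is unknown.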
func generateIndex() (uint32, error) {
	b := make([]byte, 4)

	// Let zero mean we don't know the ID, so don't generate zero
	var index uint32
	for index == 0 {
		_, err := rand.Read(b)
		if err != nil {
			l.Errorln(err)
			return 0, err
		}

		index = binary.BigEndian.Uint32(b)
	}

	if l.Level >= logrus.DebugLevel {
		l.WithField("index", index).
			Debug("Generated index")
	}
	return index, nil
}