relay rework

This commit is contained in:
JackDoan
2025-12-19 13:08:46 -06:00
parent 3338a2a2a1
commit 188b20457e

View File

@@ -20,150 +20,86 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePacketFromRelay(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { // handleRelayPackets handles relay packets. Returns false if there's nothing left to do, true for continuing to process an unwrapped TerminalType packet
//todo this is way too similar to readOutsidePacketsMany, find a way to eliminate // scratch must be large enough to contain a packet to be relayed if needed
err := h.Parse(packet) func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segment *[]byte, scratch []byte, h *header.H, nb []byte) bool {
var err error
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
seg := *segment
signedPayload := seg[:len(*segment)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := seg[len(*segment)-hostinfo.ConnectionState.dKey.Overhead():]
scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(scratch, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors return false
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
} }
return // Successfully validated the thing. Get rid of the Relay header.
} signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
//l.Error("in packet ", header, packet[HeaderLen:]) f.handleHostRoaming(hostinfo, *via)
if !via.IsRelayed { // Track usage of both the HostInfo and the Relay for the received & authenticated packet
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
}
return
}
}
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, via, h) {
return
}
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
return
}
case header.MessageRelay:
//this packet already came to us via a relay
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("from", via).Debug("Refusing to process double relayed packet")
}
return
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
return
}
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt test packet")
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(via.UdpAddr, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) {
return
}
hostinfo.logger(f.l).WithField("from", via).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
return
}
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return false
}
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// We need to re-write our variables to ensure this segment is correctly parsed.
// We could set up for a recursive call here, but this makes it easier to prove that we'll never stack-overflow
*via = ViaSender{
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
//mirrors the top of readOutsideSegment
err = h.Parse(signedPayload)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors
if len(signedPayload) > 1 {
f.l.WithField("packet", segment).Infof("Error while parsing inbound packet from %s: %s", via, err)
}
return false
}
*segment = signedPayload
//continue flowing through readOutsideSegment()
return true
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return false
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel, and find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, scratch[:0], false) //todo it would be nice to queue this up and do it later, or at least avoid a memcpy of signedPayload
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
default:
hostinfo.logger(f.l).WithField("targetRelay.Type", targetRelay.Type).Error("Unexpected Relay Type")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
}
}
return false
} }
func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
@@ -180,6 +116,11 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
// verify if we've seen this index before, otherwise respond to the handshake initiation // verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay { if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, out.Scratch[:0], h, nb)
if !keepGoing {
return
}
} else { } else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
} }
@@ -198,74 +139,16 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) { if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) {
//todo we've allocated a segment we aren't using.
//Unfortunately, we can't un-allocate it.
//Saving it for "next time" is also problematic.
//todo we need to give the segment back, but we don't want to actually send the packet to the tun. blanking the slice is probably the way to go?
return return
} }
case header.MessageRelay: case header.MessageRelay:
// The entire body is sent as AD, not encrypted. f.l.Error("relayed messages cannot contain relay messages, dropping packet")
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
out.Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return return
} }
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, via)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
via = ViaSender{
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePacketFromRelay(via, out.Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out.Scratch, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
}
case header.LightHouse: case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -376,12 +259,11 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*p
for segment := range pkt.Segments() { for segment := range pkt.Segments() {
f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now) f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now)
}
_, err := f.readers[q].WriteOne(out[i], false, q)
if err != nil {
f.l.WithError(err).Error("Failed to write packet")
} }
//_, err := f.readers[q].WriteOne(out[i], false, q)
//if err != nil {
// f.l.WithError(err).Error("Failed to write packet")
//}
} }
} }