From 188b20457e509b86885045b9635d226a9c355723 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 19 Dec 2025 13:08:46 -0600 Subject: [PATCH] relay rework --- outside.go | 302 ++++++++++++++++------------------------------------- 1 file changed, 92 insertions(+), 210 deletions(-) diff --git a/outside.go b/outside.go index 0c44cec..3f1b7c4 100644 --- a/outside.go +++ b/outside.go @@ -20,150 +20,86 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePacketFromRelay(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { - //todo this is way too similar to readOutsidePacketsMany, find a way to eliminate - err := h.Parse(packet) +// handleRelayPackets handles relay packets. Returns false if there's nothing left to do, true for continuing to process an unwrapped TerminalType packet +// scratch must be large enough to contain a packet to be relayed if needed +func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segment *[]byte, scratch []byte, h *header.H, nb []byte) bool { + var err error + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + seg := *segment + signedPayload := seg[:len(*segment)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := seg[len(*segment)-hostinfo.ConnectionState.dKey.Overhead():] + scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(scratch, signedPayload, signatureValue, h.MessageCounter, nb) if err != nil { - // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors - if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) - } - return + return false } - - //l.Error("in packet ", header, packet[HeaderLen:]) - if !via.IsRelayed { - if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") - } - return - } - } - - var hostinfo *HostInfo - // verify if we've seen this index before, otherwise respond to the handshake initiation - if h.Type == header.Message && h.Subtype == header.MessageRelay { - hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) - } else { - hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) - } - - var ci *ConnectionState - if hostinfo != nil { - ci = hostinfo.ConnectionState - } - - switch h.Type { - case header.Message: - if !f.handleEncrypted(ci, via, h) { - return - } - - switch h.Subtype { - case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) { - return - } - case header.MessageRelay: - //this packet already came to us via a relay - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("from", via).Debug("Refusing to process double relayed packet") - } - return - } - - case header.LightHouse: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt lighthouse packet") - return - } - - //TODO: assert via is not relayed - lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) - - // Fallthrough to the bottom to record incoming traffic - - case header.Test: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt test packet") - return - } - - if h.Subtype == header.TestRequest { - // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding - f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) - } - - // Fallthrough to the bottom to record incoming traffic - - // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they - // are unauthenticated - - case header.Handshake: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(via, packet, h) - return - - case header.RecvError: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(via.UdpAddr, h) - return - - case header.CloseTunnel: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - hostinfo.logger(f.l).WithField("from", via). - Info("Close tunnel received, tearing down.") - - f.closeTunnel(hostinfo) - return - - case header.Control: - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("from", via). - WithField("packet", packet). - Error("Failed to decrypt Control packet") - return - } - - f.relayManager.HandleControlMsg(hostinfo, d, f) - - default: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) - return - } - - f.handleHostRoaming(hostinfo, via) - + // Successfully validated the thing. Get rid of the Relay header. + signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. + f.handleHostRoaming(hostinfo, *via) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + return false + } + + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // We need to re-write our variables to ensure this segment is correctly parsed. + // We could set up for a recursive call here, but this makes it easier to prove that we'll never stack-overflow + *via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + //mirrors the top of readOutsideSegment + err = h.Parse(signedPayload) + if err != nil { + // Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors + if len(signedPayload) > 1 { + f.l.WithField("packet", segment).Infof("Error while parsing inbound packet from %s: %s", via, err) + } + return false + } + *segment = signedPayload + //continue flowing through readOutsideSegment() + return true + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") + return false + } + + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel, and find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, scratch[:0], false) //todo it would be nice to queue this up and do it later, or at least avoid a memcpy of signedPayload + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + default: + hostinfo.logger(f.l).WithField("targetRelay.Type", targetRelay.Type).Error("Unexpected Relay Type") + } + } else { + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + } + } + return false } func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { @@ -180,6 +116,11 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe // verify if we've seen this index before, otherwise respond to the handshake initiation if h.Type == header.Message && h.Subtype == header.MessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) + keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, out.Scratch[:0], h, nb) + if !keepGoing { + return + } + } else { hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } @@ -198,73 +139,15 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe switch h.Subtype { case header.MessageNone: if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) { + //todo we've allocated a segment we aren't using. + //Unfortunately, we can't un-allocate it. + //Saving it for "next time" is also problematic. + //todo we need to give the segment back, but we don't want to actually send the packet to the tun. blanking the slice is probably the way to go? return } case header.MessageRelay: - // The entire body is sent as AD, not encrypted. - // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. - // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's - // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice - // which will gracefully fail in the DecryptDanger call. - signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()] - signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():] - out.Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Scratch, signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { - return - } - // Successfully validated the thing. Get rid of the Relay header. - signedPayload = signedPayload[header.Len:] - // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, via) - // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo) - f.connectionManager.RelayUsed(h.RemoteIndex) - - relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) - if !ok { - // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") - return - } - - switch relay.Type { - case TerminalType: - // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - via = ViaSender{ - UdpAddr: via.UdpAddr, - relayHI: hostinfo, - remoteIdx: relay.RemoteIndex, - relay: relay, - IsRelayed: true, - } - f.readOutsidePacketFromRelay(via, out.Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now) - return - case ForwardingType: - // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) - if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") - return - } - - // If that relay is Established, forward the payload through it - if targetRelay.State == Established { - switch targetRelay.Type { - case ForwardingType: - // Forward this packet through the relay tunnel - // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out.Scratch, false) - return - case TerminalType: - hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") - } - } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") - return - } - } + f.l.Error("relayed messages cannot contain relay messages, dropping packet") + return } case header.LightHouse: @@ -376,12 +259,11 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*p for segment := range pkt.Segments() { f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now) - - } - _, err := f.readers[q].WriteOne(out[i], false, q) - if err != nil { - f.l.WithError(err).Error("Failed to write packet") } + //_, err := f.readers[q].WriteOne(out[i], false, q) + //if err != nil { + // f.l.WithError(err).Error("Failed to write packet") + //} } }