From 530cf6b3b812364519c609a4487a55a9c72c77a6 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 19 Dec 2025 14:38:12 -0600 Subject: [PATCH] checkpt, heap problems --- interface.go | 23 ++++++++--- outside.go | 105 +++++++++++++++------------------------------------ 2 files changed, 49 insertions(+), 79 deletions(-) diff --git a/interface.go b/interface.go index 739f4d0..cb53cca 100644 --- a/interface.go +++ b/interface.go @@ -270,6 +270,22 @@ func (f *Interface) run() { } } +type Scratches struct { + h *header.H + nb []byte + fwPacket *firewall.Packet + scratch []byte +} + +func NewScratches() *Scratches { + return &Scratches{ + h: &header.H{}, + fwPacket: &firewall.Packet{}, + nb: make([]byte, 12), + scratch: make([]byte, udp.MTU), + } +} + func (f *Interface) listenOut(q int) { runtime.LockOSThread() @@ -288,17 +304,14 @@ func (f *Interface) listenOut(q int) { outPackets[i] = packet.NewOut() } - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) - scratch := make([]byte, udp.MTU) + scratches := NewScratches() toSend := make([][]byte, batch) li.ListenOut(func(pkts []*packet.UDPPacket) { toSend = toSend[:0] - f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, scratch, q, ctCache.Get(f.l), time.Now()) + f.readOutsidePacketsMany(pkts, outPackets, lhh, scratches, q, ctCache.Get(f.l), time.Now()) //we opportunistically tx, but try to also send stragglers if _, err := f.readers[q].WriteMany(outPackets, q); err != nil { f.l.WithError(err).Error("Failed to send packets") diff --git a/outside.go b/outside.go index 96d5d3b..487c90c 100644 --- a/outside.go +++ b/outside.go @@ -22,7 +22,7 @@ const ( // handleRelayPackets handles relay packets. Returns false if there's nothing left to do, true for continuing to process an unwrapped TerminalType packet // scratch must be large enough to contain a packet to be relayed if needed -func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segment *[]byte, scratch []byte, h *header.H, nb []byte) bool { +func (f *Interface) handleRelayPackets(via ViaSender, hostinfo *HostInfo, segment *[]byte, scratch []byte, h *header.H, nb []byte) (*ViaSender, bool) { var err error // The entire body is sent as AD, not encrypted. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. @@ -34,12 +34,12 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme signatureValue := seg[len(*segment)-hostinfo.ConnectionState.dKey.Overhead():] scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(scratch, signedPayload, signatureValue, h.MessageCounter, nb) if err != nil { - return false + return nil, false } // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, *via) + f.handleHostRoaming(hostinfo, via) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -49,7 +49,7 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") - return false + return nil, false } switch relay.Type { @@ -57,13 +57,7 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme // If I am the target of this relay, process the unwrapped packet // We need to re-write our variables to ensure this segment is correctly parsed. // We could set up for a recursive call here, but this makes it easier to prove that we'll never stack-overflow - *via = ViaSender{ - UdpAddr: via.UdpAddr, - relayHI: hostinfo, - remoteIdx: relay.RemoteIndex, - relay: relay, - IsRelayed: true, - } + //mirrors the top of readOutsideSegment err = h.Parse(signedPayload) if err != nil { @@ -71,17 +65,24 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme if len(signedPayload) > 1 { f.l.WithField("packet", segment).Infof("Error while parsing inbound packet from %s: %s", via, err) } - return false + return nil, false + } + newVia := &ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, } *segment = signedPayload //continue flowing through readOutsideSegment() - return true + return newVia, true case ForwardingType: // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") - return false + return nil, false } // If that relay is Established, forward the payload through it @@ -99,10 +100,11 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") } } - return false + return nil, false } -func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) { +func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, lhf *LightHouseHandler, s *Scratches, q int, localCache firewall.ConntrackCache, now time.Time) { + h := s.h err := h.Parse(segment) if err != nil { // Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors @@ -116,10 +118,11 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe // verify if we've seen this index before, otherwise respond to the handshake initiation if h.Type == header.Message && h.Subtype == header.MessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) - keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, scratch[:0], h, nb) + newVia, keepGoing := f.handleRelayPackets(via, hostinfo, &segment, s.scratch, h, s.nb) if !keepGoing { return } + via = *newVia } else { hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) @@ -138,7 +141,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe switch h.Subtype { case header.MessageNone: - if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) { + if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, s.fwPacket, s.nb, q, localCache, now) { out.DestroyLastSegment() //prevent a rejected segment from being used return } @@ -153,7 +156,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe return } - d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, s.scratch, segment, h, s.nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via.UdpAddr). WithField("packet", segment). @@ -171,7 +174,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe return } - d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, s.scratch, segment, h, s.nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via). WithField("packet", segment). @@ -183,7 +186,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, scratch) + f.send(header.Test, header.TestReply, ci, hostinfo, d, s.nb, s.scratch) } // Fallthrough to the bottom to record incoming traffic @@ -218,7 +221,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe return } - d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, s.scratch, segment, h, s.nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via). WithField("packet", segment). @@ -239,23 +242,20 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe f.connectionManager.In(hostinfo) } -func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) { +func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, lhf *LightHouseHandler, s *Scratches, q int, localCache firewall.ConntrackCache, now time.Time) { for i, pkt := range packets { - scratch = scratch[:0] via := ViaSender{UdpAddr: pkt.AddrPort()} //l.Error("in packet ", header, packet[HeaderLen:]) - if !via.IsRelayed { - if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") - } - return + if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") } + return } for segment := range pkt.Segments() { - f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, scratch, q, localCache, now) + f.readOutsideSegment(via, segment, out[i], lhf, s, q, localCache, now) } //_, err := f.readers[q].WriteOne(out[i], false, q) //if err != nil { @@ -560,49 +560,6 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui return true } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool { - var err error - - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") - return false - } - - err = newPacket(out, true, fwPacket) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). - Warnf("Error while validating inbound packet") - return false - } - - if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - Debugln("dropping out of window packet") - return false - } - - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now) - if dropReason != nil { - // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore - // This gives us a buffer to build the reject packet in - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) - if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("fwPacket", fwPacket). - WithField("reason", dropReason). - Debugln("dropping inbound packet") - } - return false - } - - f.connectionManager.In(hostinfo) - _, err = f.readers[q].Write(out) - if err != nil { - f.l.WithError(err).Error("Failed to write to tun") - } - return true -} - func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) { f.sendRecvError(endpoint, index)