From aeded87e716d62dc6b5dd2a3e49c512e9f9ab12c Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 19 Dec 2025 13:37:36 -0600 Subject: [PATCH] remove awful per-packet scratch buf --- interface.go | 6 ++---- outside.go | 23 ++++++++++------------- overlay/vhostnet/device.go | 2 +- packet/outpacket.go | 30 ++++++++++++++++++++---------- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/interface.go b/interface.go index 5cde576..739f4d0 100644 --- a/interface.go +++ b/interface.go @@ -291,16 +291,14 @@ func (f *Interface) listenOut(q int) { h := &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) + scratch := make([]byte, udp.MTU) toSend := make([][]byte, batch) li.ListenOut(func(pkts []*packet.UDPPacket) { toSend = toSend[:0] - for i := range outPackets { - outPackets[i].SegCounter = 0 - } - f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now()) + f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, scratch, q, ctCache.Get(f.l), time.Now()) //we opportunistically tx, but try to also send stragglers if _, err := f.readers[q].WriteMany(outPackets, q); err != nil { f.l.WithError(err).Error("Failed to send packets") diff --git a/outside.go b/outside.go index 3f1b7c4..96d5d3b 100644 --- a/outside.go +++ b/outside.go @@ -102,7 +102,7 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme return false } -func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { +func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) { err := h.Parse(segment) if err != nil { // Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors @@ -116,7 +116,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe // verify if we've seen this index before, otherwise respond to the handshake initiation if h.Type == header.Message && h.Subtype == header.MessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) - keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, out.Scratch[:0], h, nb) + keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, scratch[:0], h, nb) if !keepGoing { return } @@ -139,10 +139,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe switch h.Subtype { case header.MessageNone: if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) { - //todo we've allocated a segment we aren't using. - //Unfortunately, we can't un-allocate it. - //Saving it for "next time" is also problematic. - //todo we need to give the segment back, but we don't want to actually send the packet to the tun. blanking the slice is probably the way to go? + out.DestroyLastSegment() //prevent a rejected segment from being used return } case header.MessageRelay: @@ -156,7 +153,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via.UdpAddr). WithField("packet", segment). @@ -174,7 +171,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via). WithField("packet", segment). @@ -186,7 +183,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out.Scratch) + f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, scratch) } // Fallthrough to the bottom to record incoming traffic @@ -221,7 +218,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via). WithField("packet", segment). @@ -242,9 +239,9 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe f.connectionManager.In(hostinfo) } -func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { +func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) { for i, pkt := range packets { - out[i].Scratch = out[i].Scratch[:0] + scratch = scratch[:0] via := ViaSender{UdpAddr: pkt.AddrPort()} //l.Error("in packet ", header, packet[HeaderLen:]) @@ -258,7 +255,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*p } for segment := range pkt.Segments() { - f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now) + f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, scratch, q, localCache, now) } //_, err := f.readers[q].WriteOne(out[i], false, q) //if err != nil { diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go index 2ab728b..b1eec44 100644 --- a/overlay/vhostnet/device.go +++ b/overlay/vhostnet/device.go @@ -122,7 +122,7 @@ func NewDevice(options ...Option) (*Device, error) { return nil, fmt.Errorf("refill receive queue: %w", err) } if err = dev.prefillTxQueue(); err != nil { - return nil, fmt.Errorf("refill tx queue: %w", err) + return nil, fmt.Errorf("prefill tx queue: %w", err) } // Make sure to clean up even when the device gets garbage collected without diff --git a/packet/outpacket.go b/packet/outpacket.go index ae2cc51..96b95cc 100644 --- a/packet/outpacket.go +++ b/packet/outpacket.go @@ -6,15 +6,14 @@ import ( ) type OutPacket struct { - Segments [][]byte + Segments [][]byte + // SegmentHeaders maps to the first virtio.NetHdrSize+14 bytes of Segments[n] + SegmentHeaders [][]byte + // SegmentPayloads maps to the remaining bytes of Segments[n] SegmentPayloads [][]byte - SegmentHeaders [][]byte - SegmentIDs []uint16 - - SegSize int - SegCounter int - - Scratch []byte + // SegmentIDs is the list of underlying buffer IDs of Segments. + // SegmentIDs, Segments, SegmentHeaders, SegmentPayloads should all have the same length at all times! + SegmentIDs []uint16 } func NewOut() *OutPacket { @@ -23,7 +22,6 @@ func NewOut() *OutPacket { out.SegmentHeaders = make([][]byte, 0, 64) out.SegmentPayloads = make([][]byte, 0, 64) out.SegmentIDs = make([]uint16, 0, 64) - out.Scratch = make([]byte, Size) return out } @@ -32,7 +30,19 @@ func (pkt *OutPacket) Reset() { pkt.SegmentPayloads = pkt.SegmentPayloads[:0] pkt.SegmentHeaders = pkt.SegmentHeaders[:0] pkt.SegmentIDs = pkt.SegmentIDs[:0] - pkt.SegSize = 0 +} + +// DestroyLastSegment removes the contents of the last segment in the list. +// Use this to handle firewall drops or similar, but still hand the segment buffer back to the underlying driver. +// Implementations shall discard zero-length segments internally. +func (pkt *OutPacket) DestroyLastSegment() { + if len(pkt.Segments) == 0 { + return + } + lastSeg := len(pkt.SegmentIDs) - 1 + pkt.SegmentPayloads[lastSeg] = pkt.SegmentPayloads[lastSeg][:0] + pkt.SegmentHeaders[lastSeg] = pkt.SegmentHeaders[lastSeg][:0] + pkt.Segments[lastSeg] = pkt.Segments[lastSeg][:0] } func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {