remove awful per-packet scratch buf

This commit is contained in:
JackDoan
2025-12-19 13:37:36 -06:00
parent 188b20457e
commit aeded87e71
4 changed files with 33 additions and 28 deletions

View File

@@ -291,16 +291,14 @@ func (f *Interface) listenOut(q int) {
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
scratch := make([]byte, udp.MTU)
toSend := make([][]byte, batch) toSend := make([][]byte, batch)
li.ListenOut(func(pkts []*packet.UDPPacket) { li.ListenOut(func(pkts []*packet.UDPPacket) {
toSend = toSend[:0] toSend = toSend[:0]
for i := range outPackets {
outPackets[i].SegCounter = 0
}
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now()) f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, scratch, q, ctCache.Get(f.l), time.Now())
//we opportunistically tx, but try to also send stragglers //we opportunistically tx, but try to also send stragglers
if _, err := f.readers[q].WriteMany(outPackets, q); err != nil { if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
f.l.WithError(err).Error("Failed to send packets") f.l.WithError(err).Error("Failed to send packets")

View File

@@ -102,7 +102,7 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme
return false return false
} }
func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
err := h.Parse(segment) err := h.Parse(segment)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors // Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors
@@ -116,7 +116,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
// verify if we've seen this index before, otherwise respond to the handshake initiation // verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay { if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, out.Scratch[:0], h, nb) keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, scratch[:0], h, nb)
if !keepGoing { if !keepGoing {
return return
} }
@@ -139,10 +139,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) { if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) {
//todo we've allocated a segment we aren't using. out.DestroyLastSegment() //prevent a rejected segment from being used
//Unfortunately, we can't un-allocate it.
//Saving it for "next time" is also problematic.
//todo we need to give the segment back, but we don't want to actually send the packet to the tun. blanking the slice is probably the way to go?
return return
} }
case header.MessageRelay: case header.MessageRelay:
@@ -156,7 +153,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via.UdpAddr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via.UdpAddr).
WithField("packet", segment). WithField("packet", segment).
@@ -174,7 +171,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
WithField("packet", segment). WithField("packet", segment).
@@ -186,7 +183,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out.Scratch) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, scratch)
} }
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
@@ -221,7 +218,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
WithField("packet", segment). WithField("packet", segment).
@@ -242,9 +239,9 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
} }
func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
for i, pkt := range packets { for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0] scratch = scratch[:0]
via := ViaSender{UdpAddr: pkt.AddrPort()} via := ViaSender{UdpAddr: pkt.AddrPort()}
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
@@ -258,7 +255,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*p
} }
for segment := range pkt.Segments() { for segment := range pkt.Segments() {
f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now) f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, scratch, q, localCache, now)
} }
//_, err := f.readers[q].WriteOne(out[i], false, q) //_, err := f.readers[q].WriteOne(out[i], false, q)
//if err != nil { //if err != nil {

View File

@@ -122,7 +122,7 @@ func NewDevice(options ...Option) (*Device, error) {
return nil, fmt.Errorf("refill receive queue: %w", err) return nil, fmt.Errorf("refill receive queue: %w", err)
} }
if err = dev.prefillTxQueue(); err != nil { if err = dev.prefillTxQueue(); err != nil {
return nil, fmt.Errorf("refill tx queue: %w", err) return nil, fmt.Errorf("prefill tx queue: %w", err)
} }
// Make sure to clean up even when the device gets garbage collected without // Make sure to clean up even when the device gets garbage collected without

View File

@@ -6,15 +6,14 @@ import (
) )
type OutPacket struct { type OutPacket struct {
Segments [][]byte Segments [][]byte
// SegmentHeaders maps to the first virtio.NetHdrSize+14 bytes of Segments[n]
SegmentHeaders [][]byte
// SegmentPayloads maps to the remaining bytes of Segments[n]
SegmentPayloads [][]byte SegmentPayloads [][]byte
SegmentHeaders [][]byte // SegmentIDs is the list of underlying buffer IDs of Segments.
SegmentIDs []uint16 // SegmentIDs, Segments, SegmentHeaders, SegmentPayloads should all have the same length at all times!
SegmentIDs []uint16
SegSize int
SegCounter int
Scratch []byte
} }
func NewOut() *OutPacket { func NewOut() *OutPacket {
@@ -23,7 +22,6 @@ func NewOut() *OutPacket {
out.SegmentHeaders = make([][]byte, 0, 64) out.SegmentHeaders = make([][]byte, 0, 64)
out.SegmentPayloads = make([][]byte, 0, 64) out.SegmentPayloads = make([][]byte, 0, 64)
out.SegmentIDs = make([]uint16, 0, 64) out.SegmentIDs = make([]uint16, 0, 64)
out.Scratch = make([]byte, Size)
return out return out
} }
@@ -32,7 +30,19 @@ func (pkt *OutPacket) Reset() {
pkt.SegmentPayloads = pkt.SegmentPayloads[:0] pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
pkt.SegmentHeaders = pkt.SegmentHeaders[:0] pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
pkt.SegmentIDs = pkt.SegmentIDs[:0] pkt.SegmentIDs = pkt.SegmentIDs[:0]
pkt.SegSize = 0 }
// DestroyLastSegment removes the contents of the last segment in the list.
// Use this to handle firewall drops or similar, but still hand the segment buffer back to the underlying driver.
// Implementations shall discard zero-length segments internally.
func (pkt *OutPacket) DestroyLastSegment() {
if len(pkt.Segments) == 0 {
return
}
lastSeg := len(pkt.SegmentIDs) - 1
pkt.SegmentPayloads[lastSeg] = pkt.SegmentPayloads[lastSeg][:0]
pkt.SegmentHeaders[lastSeg] = pkt.SegmentHeaders[lastSeg][:0]
pkt.Segments[lastSeg] = pkt.Segments[lastSeg][:0]
} }
func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int { func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {