pre-fill the tx ring with bogus empty packets

This commit is contained in:
JackDoan
2025-12-18 15:18:19 -06:00
parent 726e282d0a
commit 111efc0779
7 changed files with 57 additions and 50 deletions

View File

@@ -455,11 +455,11 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
return ErrPacketTooShort return ErrPacketTooShort
} }
version := int((data[0] >> 4) & 0x0f) //version := int((data[0] >> 4) & 0x0f)
switch version { switch data[0] & 0xf0 {
case ipv4.Version: case ipv4.Version << 4:
return parseV4(data, incoming, fp) return parseV4(data, incoming, fp)
case ipv6.Version: case ipv6.Version << 4:
return parseV6(data, incoming, fp) return parseV6(data, incoming, fp)
} }
return ErrUnknownIPVersion return ErrUnknownIPVersion

View File

@@ -806,8 +806,11 @@ func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
} }
func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error { func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
if pkt.GetPayload() == nil {
return nil
}
vpkt := pkt.(*vhostnet.VirtIOPacket) vpkt := pkt.(*vhostnet.VirtIOPacket)
err := t.vdev[q].ReceiveQueue.OfferDescriptorChains(vpkt.Chains, kick) err := t.vdev[q].ReceiveQueue.OfferDescriptorChains([]uint16{vpkt.Chain}, kick)
vpkt.Reset() //intentionally ignoring err! vpkt.Reset() //intentionally ignoring err!
return err return err
} }

View File

@@ -123,6 +123,9 @@ func NewDevice(options ...Option) (*Device, error) {
if err = dev.refillReceiveQueue(); err != nil { if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err) return nil, fmt.Errorf("refill receive queue: %w", err)
} }
if err = dev.prefillTxQueue(); err != nil {
return nil, fmt.Errorf("refill tx queue: %w", err)
}
dev.initialized = true dev.initialized = true
@@ -150,6 +153,27 @@ func (dev *Device) refillReceiveQueue() error {
} }
} }
func (dev *Device) prefillTxQueue() error {
for {
dt := dev.TransmitQueue.DescriptorTable()
for {
x, _, err := dt.CreateDescriptorForOutputs()
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
return nil
}
return err
}
err = dev.TransmitQueue.OfferDescriptorChains([]uint16{x}, false)
if err != nil {
return err
}
}
}
}
// Close cleans up the vhost networking device within the kernel and releases // Close cleans up the vhost networking device within the kernel and releases
// all resources used for it. // all resources used for it.
// The implementation will try to release as many resources as possible and // The implementation will try to release as many resources as possible and
@@ -211,15 +235,16 @@ func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*v
func (dev *Device) GetPacketForTx() (uint16, []byte, error) { func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
var err error var err error
var idx uint16 var idx uint16
if !dev.fullTable { //if !dev.fullTable {
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs() // idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
if err == virtqueue.ErrNotEnoughFreeDescriptors { // if err == virtqueue.ErrNotEnoughFreeDescriptors {
dev.fullTable = true // dev.fullTable = true
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO()) // idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
} // }
} else { //} else {
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO()) // idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
} //}
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err) return 0, nil, fmt.Errorf("transmit queue: %w", err)
} }
@@ -304,26 +329,26 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
//shift the buffer out of out: //shift the buffer out of out:
pkt.payload = buf[virtio.NetHdrSize:chain.Length] pkt.payload = buf[virtio.NetHdrSize:chain.Length]
pkt.Chains = append(pkt.Chains, idx) pkt.Chain = idx
return 1, nil return 1, nil
} }
type VirtIOPacket struct { type VirtIOPacket struct {
payload []byte payload []byte
//header virtio.NetHdr //header virtio.NetHdr
Chains []uint16 Chain uint16
} }
func NewVIO() *VirtIOPacket { func NewVIO() *VirtIOPacket {
out := new(VirtIOPacket) out := new(VirtIOPacket)
out.payload = nil out.payload = nil
out.Chains = make([]uint16, 0, 8) out.Chain = 0
return out return out
} }
func (v *VirtIOPacket) Reset() { func (v *VirtIOPacket) Reset() {
v.payload = nil v.payload = nil
v.Chains = v.Chains[:0] v.Chain = 0
} }
func (v *VirtIOPacket) GetPayload() []byte { func (v *VirtIOPacket) GetPayload() []byte {

View File

@@ -168,12 +168,12 @@ func (dt *DescriptorTable) releaseBuffers() error {
return nil return nil
} }
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) { func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, uint32, error) {
//todo just fill the damn table //todo just fill the damn table
// Do we still have enough free descriptors? // Do we still have enough free descriptors?
if 1 > dt.freeNum { if 1 > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors return 0, 0, ErrNotEnoughFreeDescriptors
} }
// Above validation ensured that there is at least one free descriptor, so // Above validation ensured that there is at least one free descriptor, so
@@ -216,7 +216,7 @@ func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
dt.descriptors[dt.freeHeadIndex].next = next dt.descriptors[dt.freeHeadIndex].next = next
} }
return head, nil return head, desc.length, nil
} }
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) { func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {

View File

@@ -301,7 +301,6 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall. // and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) { func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Create a descriptor chain for the given buffers.
var ( var (
head uint16 head uint16
err error err error
@@ -350,18 +349,6 @@ func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
sq.descriptorTable.descriptors[int(head)].length = uint32(sz) sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
} }
func (sq *SplitQueue) OfferDescriptor(chain uint16, kick bool) error {
// Make the descriptor chain available to the device.
sq.availableRing.offerSingle(chain)
// Notify the device to make it process the updated available ring.
if kick {
return sq.Kick()
}
return nil
}
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error { func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
// Make the descriptor chain available to the device. // Make the descriptor chain available to the device.
sq.availableRing.offer(chains) sq.availableRing.offer(chains)

View File

@@ -153,11 +153,7 @@ func (r *UsedRing) takeOne() (UsedElement, bool) {
} }
// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running! // InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
func (r *UsedRing) InitOfferSingle(x uint16, size int) { func (r *UsedRing) InitOfferSingle(x uint16, size uint32) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
offset := 0 offset := 0
// Add descriptor chain heads to the ring. // Add descriptor chain heads to the ring.
@@ -166,10 +162,8 @@ func (r *UsedRing) InitOfferSingle(x uint16, size int) {
// size) is always a power of 2 and smaller than the highest possible // size) is always a power of 2 and smaller than the highest possible
// 16-bit value. // 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = UsedElement{ r.ring[insertIndex].DescriptorIndex = uint32(x)
DescriptorIndex: uint32(x), r.ring[insertIndex].Length = size
Length: uint32(size),
}
// Increase the ring index by the number of descriptor chains added to the ring. // Increase the ring index by the number of descriptor chains added to the ring.
*r.ringIndex += 1 *r.ringIndex += 1

View File

@@ -87,17 +87,15 @@ func (p *UDPPacket) updateCtrl(ctrlLen int) {
if len(p.Control) == 0 { if len(p.Control) == 0 {
return return
} }
cmsgs, err := unix.ParseSocketControlMessage(p.Control) header, data, _ /*remain*/, err := unix.ParseOneSocketControlMessage(p.Control)
if err != nil { if err != nil {
return // oh well return // oh well
} }
for _, c := range cmsgs { if header.Level == unix.SOL_UDP && header.Type == unix.UDP_GRO && len(data) >= 2 {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 { p.wasSegmented = true
p.wasSegmented = true p.SegSize = int(binary.LittleEndian.Uint16(data[:2]))
p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2])) return
return
}
} }
} }