BAM! a hit from my spice-weasel. Now TX is zero-copy, at the expense of my sanity!

This commit is contained in:
JackDoan
2025-11-13 16:29:16 -06:00
parent 994bc8c32b
commit e8ea021bdd
12 changed files with 431 additions and 123 deletions

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"os"
"runtime"
"slices"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue"
@@ -31,6 +30,7 @@ type Device struct {
initialized bool
controlFD int
fullTable bool
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
}
@@ -123,6 +123,9 @@ func NewDevice(options ...Option) (*Device, error) {
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
if err = dev.refillTransmitQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
dev.initialized = true
@@ -139,7 +142,7 @@ func NewDevice(options ...Option) (*Device, error) {
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
@@ -150,6 +153,22 @@ func (dev *Device) refillReceiveQueue() error {
}
}
func (dev *Device) refillTransmitQueue() error {
//for {
// desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
// if err != nil {
// if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// // Queue is full, job is done.
// return nil
// }
// return fmt.Errorf("offer descriptor chain: %w", err)
// } else {
// dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
// }
//}
return nil
}
// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and
@@ -238,49 +257,67 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
return
}
func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
// Prepend the packet with its virtio-net header.
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
if err := vnethdr.Encode(vnethdrBuf); err != nil {
return fmt.Errorf("encode vnethdr: %w", err)
}
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
var err error
var idx uint16
if !dev.fullTable {
chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
if err == virtqueue.ErrNotEnoughFreeDescriptors {
dev.fullTable = true
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
} else {
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
return 0, nil, fmt.Errorf("transmit queue: %w", err)
}
//todo surely there's something better to do here
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
if err != nil {
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
}
return idx, buf, nil
}
for {
txedChains, err := dev.TransmitQueue.BlockAndGetHeadsCapped(context.TODO(), len(chainIndexes))
if err != nil {
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
//if pkt.Valid {
if len(pkt.SegmentIDs) == 0 {
return nil
}
for idx := range pkt.SegmentIDs {
segmentID := pkt.SegmentIDs[idx]
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
}
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
if err != nil {
return fmt.Errorf("offer descriptor chains: %w", err)
}
pkt.Reset()
//}
//if kick {
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
//}
return nil
}
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
if len(pkts) == 0 {
return nil
}
for i := range pkts {
if err := dev.TransmitPacket(pkts[i], false); err != nil {
return err
} else if len(txedChains) == 0 {
continue //todo will this ever exit?
}
for _, c := range txedChains {
idx := slices.Index(chainIndexes, c.GetHead())
if idx < 0 {
continue
} else {
_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
chainIndexes[idx] = 0 //todo I hope this works
}
}
done := true //optimism!
for _, x := range chainIndexes {
if x != 0 {
done = false
break
}
}
if done {
return nil
}
}
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
return nil
}
// TODO: Make above methods cancelable by taking a context.Context argument?
@@ -327,7 +364,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
//shift the buffer out of out:
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
pkt.Recycler = dev.ReceiveQueue.RecycleDescriptorChains
pkt.Recycler = dev.ReceiveQueue.OfferDescriptorChains
return 1, nil
//cursor := n - virtio.NetHdrSize
@@ -385,7 +422,7 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
}
// Now that we have copied all buffers, we can recycle the used descriptor chains
//if err = dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
// return 0, err
//}