BAM! a hit from my spice-weasel. Now TX is zero-copy, at the expense of my sanity!

JackDoan
2025-11-13 16:29:16 -06:00
parent 994bc8c32b
commit e8ea021bdd
12 changed files with 431 additions and 123 deletions

View File

@@ -17,7 +17,11 @@ const DefaultMTU = 1300
 type TunDev interface {
 	io.WriteCloser
 	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
-	WriteMany(x [][]byte, q int) (int, error)
+	//todo this interface sux
+	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
+	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
+	WriteMany(x []*packet.OutPacket, q int) (int, error)
 }
+// TODO: We may be able to remove routines
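
The new TunDev contract splits a transmit into three steps: reserve a device buffer (AllocSeg), fill it in place, then offer it (WriteOne or WriteMany). A minimal caller sketch, using only names that appear elsewhere in this diff (packet.NewOut, OutPacket.SegmentPayloads); the helper itself is illustrative:

func sendOne(t TunDev, payload []byte, q int) error {
	out := packet.NewOut()
	// Reserve a TX descriptor on queue q; its buffer backs segment seg.
	seg, err := t.AllocSeg(out, q)
	if err != nil {
		return err
	}
	// Fill the device-owned buffer directly: this is the zero-copy part.
	copy(out.SegmentPayloads[seg], payload)
	// Offer the descriptor chain and kick the device.
	_, err = t.WriteOne(out, true, q)
	return err
}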

View File

@@ -111,16 +111,16 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
-func (t *disabledTun) WriteMany(b [][]byte, _ int) (int, error) {
-	out := 0
-	for i := range b {
-		x, err := t.Write(b[i])
-		if err != nil {
-			return out, err
-		}
-		out += x
-	}
-	return out, nil
-}
+func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
+}
+func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
+}
+func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
+}
func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {

View File

@@ -717,17 +717,15 @@ func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
 func (t *tun) Write(b []byte) (int, error) {
 	maximum := len(b) //we are RXing
-	hdr := virtio.NetHdr{ //todo
-		Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
-		GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
-		HdrLen: 0,
-		GSOSize: 0,
-		CsumStart: 0,
-		CsumOffset: 0,
-		NumBuffers: 0,
-	}
-	err := t.vdev[0].TransmitPackets(hdr, [][]byte{b})
+	//todo garbagey
+	out := packet.NewOut()
+	x, err := t.AllocSeg(out, 0)
+	if err != nil {
+		return 0, err
+	}
+	copy(out.SegmentPayloads[x], b)
+	err = t.vdev[0].TransmitPacket(out, true)
 	if err != nil {
 		t.l.WithError(err).Error("Transmitting packet")
 		return 0, err
@@ -735,22 +733,30 @@ func (t *tun) Write(b []byte) (int, error) {
 	return maximum, nil
 }
-func (t *tun) WriteMany(b [][]byte, q int) (int, error) {
-	maximum := len(b) //we are RXing
+func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	idx, buf, err := t.vdev[q].GetPacketForTx()
+	if err != nil {
+		return 0, err
+	}
+	x := pkt.UseSegment(idx, buf)
+	return x, nil
+}
+func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
+		return 0, err
+	}
+	return 1, nil
+}
+func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	maximum := len(x) //we are RXing
 	if maximum == 0 {
 		return 0, nil
 	}
-	hdr := virtio.NetHdr{ //todo
-		Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
-		GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
-		HdrLen: 0,
-		GSOSize: 0,
-		CsumStart: 0,
-		CsumOffset: 0,
-		NumBuffers: 0,
-	}
-	err := t.vdev[q].TransmitPackets(hdr, b)
+	err := t.vdev[q].TransmitPackets(x)
 	if err != nil {
 		t.l.WithError(err).Error("Transmitting packet")
 		return 0, err
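
WriteMany hands the whole batch to TransmitPackets, which is shaped to kick the device once per batch (each TransmitPacket is invoked with kick=false, followed by a single Kick), although the kick check inside TransmitPacket is still commented out in this commit. An illustrative caller-side batch, reusing the AllocSeg pattern from Write above:

func sendBatch(t TunDev, payloads [][]byte, q int) (int, error) {
	batch := make([]*packet.OutPacket, 0, len(payloads))
	for _, p := range payloads {
		out := packet.NewOut()
		seg, err := t.AllocSeg(out, q) // reserve a device buffer per packet
		if err != nil {
			return 0, err
		}
		copy(out.SegmentPayloads[seg], p) // fill in place, no intermediate copy
		batch = append(batch, out)
	}
	return t.WriteMany(batch, q) // one call for the whole batch
}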

View File

@@ -1,6 +1,7 @@
 package overlay

 import (
+	"fmt"
 	"io"
 	"net/netip"
@@ -71,14 +72,14 @@ func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return d.Read(b[0].Payload)
}
-func (d *UserDevice) WriteMany(b [][]byte, _ int) (int, error) {
-	out := 0
-	for i := range b {
-		x, err := d.Write(b[i])
-		if err != nil {
-			return out, err
-		}
-		out += x
-	}
-	return out, nil
-}
+func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: AllocSeg not implemented")
+}
+func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteOne not implemented")
+}
+func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteMany not implemented")
+}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"os"
"runtime"
"slices"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue"
@@ -31,6 +30,7 @@ type Device struct {
initialized bool
controlFD int
fullTable bool
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
}
@@ -123,6 +123,9 @@ func NewDevice(options ...Option) (*Device, error) {
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
if err = dev.refillTransmitQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
dev.initialized = true
@@ -139,7 +142,7 @@ func NewDevice(options ...Option) (*Device, error) {
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
-		_, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
+		_, err := dev.ReceiveQueue.OfferInDescriptorChains()
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
@@ -150,6 +153,22 @@ func (dev *Device) refillReceiveQueue() error {
}
}
func (dev *Device) refillTransmitQueue() error {
//for {
// desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
// if err != nil {
// if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// // Queue is full, job is done.
// return nil
// }
// return fmt.Errorf("offer descriptor chain: %w", err)
// } else {
// dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
// }
//}
return nil
}
// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and
@@ -238,49 +257,67 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
return
}
-func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
-	// Prepend the packet with its virtio-net header.
-	vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
-	if err := vnethdr.Encode(vnethdrBuf); err != nil {
-		return fmt.Errorf("encode vnethdr: %w", err)
-	}
-	vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
-	vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
-	chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
-	if err != nil {
-		return fmt.Errorf("offer descriptor chain: %w", err)
-	}
-	for {
-		txedChains, err := dev.TransmitQueue.BlockAndGetHeadsCapped(context.TODO(), len(chainIndexes))
-		if err != nil {
-			return err
-		} else if len(txedChains) == 0 {
-			continue //todo will this ever exit?
-		}
-		for _, c := range txedChains {
-			idx := slices.Index(chainIndexes, c.GetHead())
-			if idx < 0 {
-				continue
-			} else {
-				_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
-				chainIndexes[idx] = 0 //todo I hope this works
-			}
-		}
-		done := true //optimism!
-		for _, x := range chainIndexes {
-			if x != 0 {
-				done = false
-				break
-			}
-		}
-		if done {
-			return nil
-		}
-	}
-}
+func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
+	var err error
+	var idx uint16
+	if !dev.fullTable {
+		idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
+		if err == virtqueue.ErrNotEnoughFreeDescriptors {
+			dev.fullTable = true
+			idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+		}
+	} else {
+		idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+	}
+	if err != nil {
+		return 0, nil, fmt.Errorf("transmit queue: %w", err)
+	}
+	//todo surely there's something better to do here
+	buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
+	if err != nil {
+		return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
+	}
+	return idx, buf, nil
+}
+func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
+	//if pkt.Valid {
+	if len(pkt.SegmentIDs) == 0 {
+		return nil
+	}
+	for idx := range pkt.SegmentIDs {
+		segmentID := pkt.SegmentIDs[idx]
+		dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
+	}
+	err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
+	if err != nil {
+		return fmt.Errorf("offer descriptor chains: %w", err)
+	}
+	pkt.Reset()
+	//}
+	//if kick {
+	if err := dev.TransmitQueue.Kick(); err != nil {
+		return err
+	}
+	//}
+	return nil
+}
+func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
+	if len(pkts) == 0 {
+		return nil
+	}
+	for i := range pkts {
+		if err := dev.TransmitPacket(pkts[i], false); err != nil {
+			return err
+		}
+	}
+	if err := dev.TransmitQueue.Kick(); err != nil {
+		return err
+	}
+	return nil
+}
// TODO: Make above methods cancelable by taking a context.Context argument?
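
GetPacketForTx embodies the new descriptor lifecycle: lease from the free table via CreateDescriptorForOutputs until it runs dry, then flip fullTable and reclaim completed descriptors from the used ring via TakeSingle. A hedged sketch of the resulting vhost-level TX flow (assuming, as tun.Write above suggests, that OutPacket.SegmentPayloads accounts for the virtio-net header; the helper is illustrative):

func transmitRaw(dev *Device, frame []byte) error {
	// Lease a descriptor and its backing buffer.
	idx, buf, err := dev.GetPacketForTx()
	if err != nil {
		return err
	}
	pkt := packet.NewOut()
	seg := pkt.UseSegment(idx, buf) // record the leased descriptor
	copy(pkt.SegmentPayloads[seg], frame)
	// Sets descriptor sizes, offers the chain, and kicks the device.
	return dev.TransmitPacket(pkt, true)
}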
@@ -327,7 +364,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
//shift the buffer out of out:
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
-	pkt.Recycler = dev.ReceiveQueue.RecycleDescriptorChains
+	pkt.Recycler = dev.ReceiveQueue.OfferDescriptorChains
return 1, nil
//cursor := n - virtio.NetHdrSize
@@ -385,7 +422,7 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
}
// Now that we have copied all buffers, we can recycle the used descriptor chains
-	//if err = dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
+	//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
// return 0, err
//}
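
With the rename, an RX packet's Recycler now re-offers consumed descriptor chains straight back to the available ring instead of freeing them. A hedged consumer sketch (pkt.Chains and Recycler come from this diff; the helper is illustrative):

func releaseRxPacket(pkt *packet.VirtIOPacket) error {
	// Re-offer the chains to the device and kick it.
	err := pkt.Recycler(pkt.Chains, true)
	pkt.Chains = pkt.Chains[:0]
	return err
}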

View File

@@ -281,6 +281,106 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
return head, nil
}
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
//todo just fill the damn table
// Do we still have enough free descriptors?
if 1 > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// This descriptor is device-readable (TX), so expose the full item size
// for the driver to fill in before offering.
desc.length = uint32(dt.itemSize)
desc.flags = 0 // device-readable: descriptorFlagWritable deliberately not set
desc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
// Do we still have enough free descriptors?
if 1 > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
desc.flags = descriptorFlagWritable
desc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
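
CreateDescriptorForOutputs and createDescriptorForInputs share their body and differ only in the flags written to the descriptor. A possible consolidation (illustrative sketch, not part of this commit; all names come from the code above):

func (dt *DescriptorTable) createSingleDescriptor(flags uint16) (uint16, error) {
	if dt.freeNum < 1 {
		return 0, ErrNotEnoughFreeDescriptors
	}
	if dt.freeHeadIndex == noFreeHead {
		panic("free descriptor chain head is unset but there should be free descriptors")
	}
	// Take the descriptor after the free head so the head is touched last.
	head := dt.descriptors[dt.freeHeadIndex].next
	desc := &dt.descriptors[head]
	next := desc.next
	checkUnusedDescriptorLength(head, desc)
	desc.length = uint32(dt.itemSize)
	desc.flags = flags // 0 for device-readable (TX), descriptorFlagWritable for RX
	desc.next = 0
	dt.freeNum--
	if dt.freeNum == 0 {
		if next != dt.freeHeadIndex {
			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
		}
		dt.freeHeadIndex = noFreeHead
	} else {
		// Close the free chain circle again.
		dt.descriptors[dt.freeHeadIndex].next = next
	}
	return head, nil
}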
// TODO: Implement a zero-copy variant of createDescriptorChain?
// getDescriptorChain returns the device-readable buffers (out buffers) and
@@ -334,6 +434,20 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
return
}
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
if int(head) >= len(dt.descriptors) {
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
return bs, nil
}
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
if int(head) > len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)

View File

@@ -208,6 +208,32 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro
return nil, ctx.Err()
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
var n int
var err error
for ctx.Err() == nil {
out, ok := sq.usedRing.takeOne()
if ok {
return out, nil
}
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return 0, fmt.Errorf("wait: %w", err)
}
if n > 0 {
out, ok = sq.usedRing.takeOne()
if ok {
_ = sq.epoll.Clear() //???
return out, nil
} else {
continue //???
}
}
}
return 0, ctx.Err()
}
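
TakeSingle blocks until the device hands back a completed descriptor, so a caller that must not stall forever can bound it with a deadline (illustrative; GetPacketForTx currently passes context.TODO()). On timeout it returns 0 and ctx.Err():

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
idx, err := sq.TakeSingle(ctx) // descriptor head, or ctx.Err() on timeout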
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
var n int
var err error
@@ -268,14 +294,14 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
// they're done with them. When this does not happen, the queue will run full
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
-func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int) (uint16, error) {
+func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for {
-		head, err = sq.descriptorTable.createDescriptorChain(nil, numInBuffers)
+		head, err = sq.descriptorTable.createDescriptorForInputs()
if err == nil {
break
}
@@ -361,6 +387,11 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
return sq.descriptorTable.getDescriptorChain(head)
}
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
return sq.descriptorTable.getDescriptorItem(head)
}
func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
}
@@ -387,7 +418,12 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
return nil
}
-func (sq *SplitQueue) RecycleDescriptorChains(chains []uint16, kick bool) error {
+func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
+	//not called under lock
+	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
+}
+
+func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
+	//todo not doing this may break eventually?
//not called under lock
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
@@ -399,14 +435,19 @@ func (sq *SplitQueue) RecycleDescriptorChains(chains []uint16, kick bool) error
// Notify the device to make it process the updated available ring.
if kick {
-		if err := sq.kickEventFD.Kick(); err != nil {
-			return fmt.Errorf("notify device: %w", err)
-		}
+		return sq.Kick()
}
return nil
}
func (sq *SplitQueue) Kick() error {
if err := sq.kickEventFD.Kick(); err != nil {
return fmt.Errorf("notify device: %w", err)
}
return nil
}
// Close releases all resources used for this queue.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.

View File

@@ -127,3 +127,58 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
return stillNeedToTake, elems
}
func (r *UsedRing) takeOne() (uint16, bool) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0xffff, false
}
// Calculate the number of new used elements that we can read from the ring.
// The ring index may wrap, so special handling for that case is needed.
count := int(ringIndex - r.lastIndex)
if count < 0 {
count += 0xffff
}
// The number of new elements can never exceed the queue size.
if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long")
}
if count == 0 {
return 0xffff, false
}
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
r.lastIndex++
return out, true
}
// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
func (r *UsedRing) InitOfferSingle(x uint16, size int) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
offset := 0
// Add descriptor chain heads to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = UsedElement{
DescriptorIndex: uint32(x),
Length: uint32(size),
}
// Increase the ring index by the number of descriptor chains added to the ring.
*r.ringIndex += 1
}