broken chkpt

This commit is contained in:
JackDoan
2025-11-11 15:37:54 -06:00
parent e7f01390a3
commit 1719149594
16 changed files with 275 additions and 303 deletions

View File

@@ -1,14 +1,16 @@
package vhostnet
import (
"context"
"errors"
"fmt"
"os"
"runtime"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtio"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
@@ -31,12 +33,7 @@ type Device struct {
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
// transmitted contains channels for each possible descriptor chain head
// index. This is used for packet transmit notifications.
// When a packet was transmitted and the descriptor chain was used by the
// device, the corresponding channel receives the [virtqueue.UsedElement]
// instance provided by the device.
transmitted []chan virtqueue.UsedElement
extraRx []virtqueue.UsedElement
}
// NewDevice initializes a new vhost networking device within the
@@ -126,25 +123,6 @@ func NewDevice(options ...Option) (*Device, error) {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
// Initialize channels for transmit notifications.
dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
for i := range len(dev.transmitted) {
// It is important to use a single-element buffered channel here.
// When the channel was unbuffered and the monitorTransmitQueue
// goroutine would write into it, the writing would block which could
// lead to deadlocks in case transmit notifications do not arrive in
// order.
// When the goroutine would use fire-and-forget to write into that
// channel, there may be a chance that the TransmitPacket does not
// receive the transmit notification due to this being a race condition.
// Buffering a single transmit notification resolves this without race
// conditions or possible deadlocks.
dev.transmitted[i] = make(chan virtqueue.UsedElement, 1)
}
// Monitor transmit queue in background.
go dev.monitorTransmitQueue()
dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without
@@ -155,32 +133,12 @@ func NewDevice(options ...Option) (*Device, error) {
return devPtr, nil
}
// monitorTransmitQueue waits for the device to advertise used descriptor chains
// in the transmit queue and produces a transmit notification via the
// corresponding channel.
func (dev *Device) monitorTransmitQueue() {
usedChan := dev.TransmitQueue.UsedDescriptorChains()
for {
used, ok := <-usedChan
if !ok {
// The queue was closed.
return
}
if int(used.DescriptorIndex) > len(dev.transmitted) {
panic(fmt.Sprintf("device provided a used descriptor index (%d) that is out of range",
used.DescriptorIndex))
}
dev.transmitted[used.DescriptorIndex] <- used
}
}
// refillReceiveQueue offers as many new device-writable buffers to the device
// as the queue can fit. The device will then use these to write received
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false)
_, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
@@ -279,38 +237,6 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
return
}
// TransmitPacket writes the given packet into the transmit queue of this
// device. The packet will be prepended with the [virtio.NetHdr].
//
// When the queue is full, this will block until the queue has enough room to
// transmit the packet. This method will not return before the packet was
// transmitted and the device notifies that it has used the packet buffer.
func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
// Prepend the packet with its virtio-net header.
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
if err := vnethdr.Encode(vnethdrBuf); err != nil {
return fmt.Errorf("encode vnethdr: %w", err)
}
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
outBuffers := [][]byte{vnethdrBuf, packet}
//outBuffers := [][]byte{packet}
chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
// Wait for the packet to have been transmitted.
<-dev.transmitted[chainIndex]
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
return nil
}
func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
// Prepend the packet with its virtio-net header.
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
@@ -320,15 +246,42 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
//todo surely there's something better to do here
doneYet := map[uint16]bool{}
for _, chain := range chainIndexes {
doneYet[chain] = false
}
for {
txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO())
if err != nil {
return err
} else if len(txedChains) == 0 {
continue //todo will this ever exit?
}
for c := range txedChains {
doneYet[txedChains[c].GetHead()] = true
}
done := true //optimism!
for _, x := range doneYet {
if !x {
done = false
break
}
}
if done {
break
}
}
//todo blocking here suxxxx
// Wait for the packet to have been transmitted.
for i := range chainIndexes {
<-dev.transmitted[chainIndexes[i]]
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
@@ -338,106 +291,104 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
return nil
}
// ReceivePacket reads the next available packet from the receive queue of this
// device and returns its [virtio.NetHdr] and packet data separately.
//
// When no packet is available, this will block until there is one.
//
// When this method returns an error, the receive queue will likely be in a
// broken state which this implementation cannot recover from. The caller should
// close the device and not attempt any additional receives.
func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
var (
chainHeads []uint16
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?
vnethdr virtio.NetHdr
buffers [][]byte
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need:
pkt.Payload = pkt.Payload[:cap(pkt.Payload)]
n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload)
if err != nil {
return 0, err
}
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = pkt.Header.Decode(pkt.Payload[0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return 0, fmt.Errorf("decode vnethdr: %w", err)
}
// Each packet starts with a virtio-net header which we have to subtract
// from the total length.
packetLength = -virtio.NetHdrSize
)
//we have the header now: what do we need to do?
if int(pkt.Header.NumBuffers) > len(chains) {
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
}
lenRead := 0
//shift the buffer out of out:
copy(pkt.Payload, pkt.Payload[virtio.NetHdrSize:])
// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
// made of multiple descriptor chains which are to be merged.
for remainingChains := 1; remainingChains > 0; remainingChains-- {
// Get the next descriptor chain.
usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
if !ok {
return 0, virtio.NetHdr{}, ErrDeviceClosed
}
cursor := n - virtio.NetHdrSize
// Track this chain to be freed later.
head := uint16(usedElement.DescriptorIndex)
chainHeads = append(chainHeads, head)
if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
return 1, nil
}
n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:])
i := 1
// we used chain 0 already
for i = 1; i < len(chains); i++ {
n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:])
if err != nil {
// When this fails we may miss to free some descriptor chains. We
// could try to mitigate this by deferring the freeing somehow, but
// it's not worth the hassle. When this method fails, the queue will
// be in a broken state anyway.
return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err)
return i, fmt.Errorf("get descriptor chain: %w", err)
}
lenRead += n
packetLength += int(usedElement.Length)
cursor += n
}
//todo this has to be wrong
pkt.Payload = pkt.Payload[:cursor]
return i, nil
}
// Is this the first descriptor chain we process?
if len(buffers) == 0 {
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = vnethdr.Decode(out[0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err)
}
lenRead = 0
out = out[virtio.NetHdrSize:]
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
//todo optimize?
var chains []virtqueue.UsedElement
var err error
//if len(dev.extraRx) == 0 {
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), 64) //todo config batch
if err != nil {
return 0, err
}
if len(chains) == 0 {
return 0, nil
}
//} else {
// chains = dev.extraRx
//}
// The virtio-net header tells us how many descriptor chains this
// packet is long.
remainingChains = int(vnethdr.NumBuffers)
numPackets := 0
chainsIdx := 0
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
if numPackets >= len(out) {
//dev.extraRx = chains[chainsIdx:]
//return numPackets, nil
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
}
//buffers = append(buffers, inBuffers...)
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
if err != nil {
return 0, err
}
chainsIdx += numChains
}
// Copy all the buffers together to produce the complete packet slice.
//out = out[:packetLength]
//copied := 0
//for _, buffer := range buffers {
// copied += copy(out[copied:], buffer)
//}
//if copied != packetLength {
// panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
//}
// Now that we have copied all buffers, we can free the used descriptor
// chains again.
// TODO: Recycling the descriptor chains would be more efficient than
// freeing them just to offer them again right after.
for _, head := range chainHeads {
if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
}
// Now that we have copied all buffers, we can recycle the used descriptor chains
if err := dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
return 0, err
}
//if we don't churn chains, maybe we don't need this?
// It's advised to always keep the receive queue fully populated with
// available buffers which the device can write new packets into.
// It's advised to always keep the rx queue fully populated with available buffers which the device can write new packets into.
//if err := dev.refillReceiveQueue(); err != nil {
// return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
//}
return packetLength, vnethdr, nil
return numPackets, nil
}
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?