broken chkpt

This commit is contained in:
JackDoan
2025-11-11 11:38:43 -06:00
parent c645a45438
commit e7f01390a3
8 changed files with 271 additions and 113 deletions

View File

@@ -28,8 +28,8 @@ type Device struct {
initialized bool
controlFD int
receiveQueue *virtqueue.SplitQueue
transmitQueue *virtqueue.SplitQueue
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
// transmitted contains channels for each possible descriptor chain head
// index. This is used for packet transmit notifications.
@@ -96,17 +96,17 @@ func NewDevice(options ...Option) (*Device, error) {
}
// Initialize and register the queues needed for the networking device.
if dev.receiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
if dev.transmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
// Set up memory mappings for all buffers used by the queues. This has to
// happen before a backend for the queues can be registered.
memoryLayout := vhost.NewMemoryLayoutForQueues(
[]*virtqueue.SplitQueue{dev.receiveQueue, dev.transmitQueue},
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
)
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
return nil, fmt.Errorf("setup memory layout: %w", err)
@@ -127,7 +127,7 @@ func NewDevice(options ...Option) (*Device, error) {
}
// Initialize channels for transmit notifications.
dev.transmitted = make([]chan virtqueue.UsedElement, dev.transmitQueue.Size())
dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
for i := range len(dev.transmitted) {
// It is important to use a single-element buffered channel here.
// When the channel was unbuffered and the monitorTransmitQueue
@@ -159,7 +159,7 @@ func NewDevice(options ...Option) (*Device, error) {
// in the transmit queue and produces a transmit notification via the
// corresponding channel.
func (dev *Device) monitorTransmitQueue() {
usedChan := dev.transmitQueue.UsedDescriptorChains()
usedChan := dev.TransmitQueue.UsedDescriptorChains()
for {
used, ok := <-usedChan
if !ok {
@@ -180,7 +180,7 @@ func (dev *Device) monitorTransmitQueue() {
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.receiveQueue.OfferDescriptorChain(nil, 1, false)
_, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false)
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
@@ -212,17 +212,17 @@ func (dev *Device) Close() error {
var errs []error
if dev.receiveQueue != nil {
if err := dev.receiveQueue.Close(); err == nil {
dev.receiveQueue = nil
if dev.ReceiveQueue != nil {
if err := dev.ReceiveQueue.Close(); err == nil {
dev.ReceiveQueue = nil
} else {
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
}
}
if dev.transmitQueue != nil {
if err := dev.transmitQueue.Close(); err == nil {
dev.transmitQueue = nil
if dev.TransmitQueue != nil {
if err := dev.TransmitQueue.Close(); err == nil {
dev.TransmitQueue = nil
} else {
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
}
@@ -296,7 +296,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
outBuffers := [][]byte{vnethdrBuf, packet}
//outBuffers := [][]byte{packet}
chainIndex, err := dev.transmitQueue.OfferDescriptorChain(outBuffers, 0, true)
chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
@@ -304,7 +304,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
// Wait for the packet to have been transmitted.
<-dev.transmitted[chainIndex]
if err = dev.transmitQueue.FreeDescriptorChain(chainIndex); err != nil {
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
@@ -320,7 +320,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
@@ -330,7 +330,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
for i := range chainIndexes {
<-dev.transmitted[chainIndexes[i]]
if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
}
@@ -346,7 +346,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
// When this method returns an error, the receive queue will likely be in a
// broken state which this implementation cannot recover from. The caller should
// close the device and not attempt any additional receives.
func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
var (
chainHeads []uint16
@@ -358,41 +358,30 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
packetLength = -virtio.NetHdrSize
)
lenRead := 0
// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
// made of multiple descriptor chains which are to be merged.
for remainingChains := 1; remainingChains > 0; remainingChains-- {
// Get the next descriptor chain.
usedElement, ok := <-dev.receiveQueue.UsedDescriptorChains()
usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
if !ok {
return virtio.NetHdr{}, nil, ErrDeviceClosed
return 0, virtio.NetHdr{}, ErrDeviceClosed
}
// Track this chain to be freed later.
head := uint16(usedElement.DescriptorIndex)
chainHeads = append(chainHeads, head)
outBuffers, inBuffers, err := dev.receiveQueue.GetDescriptorChain(head)
n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:])
if err != nil {
// When this fails we may miss to free some descriptor chains. We
// could try to mitigate this by deferring the freeing somehow, but
// it's not worth the hassle. When this method fails, the queue will
// be in a broken state anyway.
return virtio.NetHdr{}, nil, fmt.Errorf("get descriptor chain: %w", err)
return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err)
}
if len(outBuffers) > 0 {
// How did this happen!?
panic("receive queue contains device-readable buffers")
}
if len(inBuffers) == 0 {
// Empty descriptor chains should not be possible.
panic("descriptor chain contains no buffers")
}
// The device tells us how many bytes of the descriptor chain it has
// actually written to. The specification forces the device to fully
// fill up all but the last descriptor chain when multiple descriptor
// chains are being merged, but being more compatible here doesn't hurt.
inBuffers = truncateBuffers(inBuffers, int(usedElement.Length))
lenRead += n
packetLength += int(usedElement.Length)
// Is this the first descriptor chain we process?
@@ -403,49 +392,51 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = vnethdr.Decode(inBuffers[0]); err != nil {
if err = vnethdr.Decode(out[0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return virtio.NetHdr{}, nil, fmt.Errorf("decode vnethdr: %w", err)
return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err)
}
inBuffers[0] = inBuffers[0][virtio.NetHdrSize:]
lenRead = 0
out = out[virtio.NetHdrSize:]
// The virtio-net header tells us how many descriptor chains this
// packet is long.
remainingChains = int(vnethdr.NumBuffers)
}
buffers = append(buffers, inBuffers...)
//buffers = append(buffers, inBuffers...)
}
// Copy all the buffers together to produce the complete packet slice.
packet := make([]byte, packetLength)
copied := 0
for _, buffer := range buffers {
copied += copy(packet[copied:], buffer)
}
if copied != packetLength {
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
}
//out = out[:packetLength]
//copied := 0
//for _, buffer := range buffers {
// copied += copy(out[copied:], buffer)
//}
//if copied != packetLength {
// panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
//}
// Now that we have copied all buffers, we can free the used descriptor
// chains again.
// TODO: Recycling the descriptor chains would be more efficient than
// freeing them just to offer them again right after.
for _, head := range chainHeads {
if err := dev.receiveQueue.FreeDescriptorChain(head); err != nil {
return virtio.NetHdr{}, nil, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
}
}
//if we don't churn chains, maybe we don't need this?
// It's advised to always keep the receive queue fully populated with
// available buffers which the device can write new packets into.
if err := dev.refillReceiveQueue(); err != nil {
return virtio.NetHdr{}, nil, fmt.Errorf("refill receive queue: %w", err)
}
//if err := dev.refillReceiveQueue(); err != nil {
// return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
//}
return vnethdr, packet, nil
return packetLength, vnethdr, nil
}
// TODO: Make above methods cancelable by taking a context.Context argument?