refactoring a bit

This commit is contained in:
JackDoan
2025-12-18 13:27:28 -06:00
parent f5c46c43ce
commit 41c9a3b2eb
19 changed files with 229 additions and 387 deletions

View File

@@ -10,10 +10,6 @@ import (
)
var (
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
// no buffers, which is not allowed.
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
// exhausted, meaning that the queue is full.
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
@@ -272,59 +268,6 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
return head, nil
}
// TODO: Implement a zero-copy variant of createDescriptorChain?
// getDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain that starts with
// the given head index. The descriptor chain must have been created using
// [createDescriptorChain] and must not have been freed yet (meaning that the
// head index must not be contained in the free chain).
//
// Be careful to only access the returned buffer slices when the device has not
// yet or is no longer using them. They must not be accessed after
// [freeDescriptorChain] has been called.
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
if int(head) > len(dt.descriptors) {
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
outBuffers = append(outBuffers, bs)
} else {
inBuffers = append(inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return
}
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
if int(head) > len(dt.descriptors) {
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
@@ -339,121 +282,6 @@ func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
return bs, nil
}
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
if int(head) > len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
return fmt.Errorf("there should not be an outbuffer in %d", head)
} else {
*inBuffers = append(*inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return nil
}
// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
// of [createDescriptorChain].
// The descriptor chain must have been created using [createDescriptorChain] and
// must not have been freed yet (meaning that the head index must not be
// contained in the free chain).
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
if int(head) > len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
var tailDesc *Descriptor
var chainLen uint16
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
chainLen++
// Set the length of all unused descriptors back to zero.
desc.length = 0
// Unset all flags except the next flag.
desc.flags &= descriptorFlagHasNext
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
tailDesc = desc
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if tailDesc == nil {
// A descriptor chain longer than the queue size but without loops
// should be impossible.
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
}
// The tail descriptor does not have the next flag set, but when it comes
// back into the free chain, it should have.
tailDesc.flags = descriptorFlagHasNext
if dt.freeHeadIndex == noFreeHead {
// The whole free chain was used up, so we turn this returned descriptor
// chain into the new free chain by completing the circle and using its
// head.
tailDesc.next = head
dt.freeHeadIndex = head
} else {
// Attach the returned chain at the beginning of the free chain but
// right after the free chain head.
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
tailDesc.next = freeHeadDesc.next
freeHeadDesc.next = head
}
dt.freeNum += chainLen
return nil
}
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
// is zero, as it should be.
// This is not a requirement by the virtio spec but rather a thing we do to

View File

@@ -128,8 +128,7 @@ func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
return nil, err
}
// Consume used buffer notifications in the background.
sq.stop = sq.startConsumeUsedRing()
sq.stop = sq.kickSelfToExit()
return &sq, nil
}
@@ -169,9 +168,7 @@ func (sq *SplitQueue) CallEventFD() int {
return sq.callEventFD.FD()
}
// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
// A function is returned that can be used to gracefully cancel it. todo rename
func (sq *SplitQueue) startConsumeUsedRing() func() error {
func (sq *SplitQueue) kickSelfToExit() func() error {
return func() error {
// The goroutine blocks until it receives a signal on the event file
@@ -185,7 +182,15 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
}
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
func (sq *SplitQueue) TakeSingleIndex(ctx context.Context) (uint16, error) {
element, err := sq.TakeSingle(ctx)
if err != nil {
return 0xffff, err
}
return element.GetHead(), nil
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
@@ -195,7 +200,7 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
}
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return 0, fmt.Errorf("wait: %w", err)
return UsedElement{}, fmt.Errorf("wait: %w", err)
}
if n > 0 {
@@ -208,7 +213,31 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
}
}
}
return 0, ctx.Err()
return UsedElement{}, ctx.Err()
}
func (sq *SplitQueue) TakeSingleNoBlock() (UsedElement, bool) {
return sq.usedRing.takeOne()
}
func (sq *SplitQueue) WaitForUsedElements(ctx context.Context) error {
if sq.usedRing.availableToTake() != 0 {
return nil
}
for ctx.Err() == nil {
// Wait for a signal from the device.
n, err := sq.epoll.Block()
if err != nil {
return fmt.Errorf("wait: %w", err)
}
if n > 0 {
_ = sq.epoll.Clear()
if sq.usedRing.availableToTake() != 0 {
return nil
}
}
}
return ctx.Err()
}
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
@@ -235,7 +264,7 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
_ = sq.epoll.Clear() //???
_ = sq.epoll.Clear()
stillNeedToTake, out = sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
@@ -296,16 +325,14 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
sq.availableRing.offerSingle(head)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
if err = sq.kickEventFD.Kick(); err != nil {
return head, fmt.Errorf("notify device: %w", err)
}
return head, nil
}
// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.
// GetDescriptorItem returns the buffer of a given index
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
@@ -313,37 +340,11 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
return sq.descriptorTable.getDescriptorChain(head)
}
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
return sq.descriptorTable.getDescriptorItem(head)
}
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
}
// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// This creates new room in the queue which can be used by following
// [SplitQueue.OfferDescriptorChain] calls.
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
// are waiting for free room in the queue, they may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
//not called under lock
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
return fmt.Errorf("free: %w", err)
}
return nil
}
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
//not called under lock
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)

View File

@@ -84,17 +84,11 @@ func (r *UsedRing) Address() uintptr {
return uintptr(unsafe.Pointer(r.flags))
}
// take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method.
// had a lock, I removed it
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
//r.mu.Lock()
//defer r.mu.Unlock()
func (r *UsedRing) availableToTake() int {
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0, nil
return 0
}
// Calculate the number new used elements that we can read from the ring.
@@ -103,6 +97,16 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
if count < 0 {
count += 0xffff
}
return count
}
// take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method.
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
count := r.availableToTake()
if count == 0 {
return 0, nil
}
stillNeedToTake := 0
@@ -128,21 +132,13 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
return stillNeedToTake, elems
}
func (r *UsedRing) takeOne() (uint16, bool) {
func (r *UsedRing) takeOne() (UsedElement, bool) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0xffff, false
}
// Calculate the number new used elements that we can read from the ring.
// The ring index may wrap, so special handling for that case is needed.
count := int(ringIndex - r.lastIndex)
if count < 0 {
count += 0xffff
count := r.availableToTake()
if count == 0 {
return UsedElement{}, false
}
// The number of new elements can never exceed the queue size.
@@ -150,11 +146,7 @@ func (r *UsedRing) takeOne() (uint16, bool) {
panic("used ring contains more new elements than the ring is long")
}
if count == 0 {
return 0xffff, false
}
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
out := r.ring[r.lastIndex%uint16(len(r.ring))]
r.lastIndex++
return out, true