commit 987f45baf0
parent edff19a05b
Author: JackDoan
Date:   2025-11-08 13:01:40 -06:00

5 changed files with 42 additions and 28 deletions

View File

@@ -736,6 +736,7 @@ func (t *tun) Write(b []byte) (int, error) {
     err := t.vdev.TransmitPacket(hdr, b)
     if err != nil {
+        t.l.WithError(err).Error("Transmitting packet")
         return 0, err
     }
     return maximum, nil

View File

@@ -2,7 +2,6 @@ package virtqueue

 import (
     "fmt"
-    "sync"
     "unsafe"
 )
@@ -49,7 +48,7 @@ type AvailableRing struct {
     // virtio specification.
     usedEvent *uint16

-    mu sync.Mutex
+    //mu sync.Mutex
 }

 // newAvailableRing creates an available ring that uses the given underlying
@@ -84,8 +83,9 @@ func (r *AvailableRing) Address() uintptr {
 // advances the ring index accordingly to make the device process the new
 // descriptor chains.
 func (r *AvailableRing) offer(chainHeads []uint16) {
-    r.mu.Lock()
-    defer r.mu.Unlock()
+    // always called under lock
+    //r.mu.Lock()
+    //defer r.mu.Unlock()

     // Add descriptor chain heads to the ring.
     for offset, head := range chainHeads {
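
Why the ring-level mutex can go: offer is only reached from SplitQueue.OfferDescriptorChain, which already serializes callers with offerMutex (see the split queue diff below). A minimal standalone sketch of that relationship, using simplified stand-in types rather than this repo's real API:

package main

import "sync"

type availableRing struct{ heads []uint16 }

// offer assumes the caller already holds the queue-level lock.
func (r *availableRing) offer(chainHeads []uint16) {
    r.heads = append(r.heads, chainHeads...)
}

type splitQueue struct {
    offerMutex sync.Mutex
    ring       availableRing
}

func (sq *splitQueue) offerDescriptorChain(heads []uint16) {
    sq.offerMutex.Lock()
    defer sq.offerMutex.Unlock()
    sq.ring.offer(heads) // always called under lock, as the new comment notes
}

func main() {
    sq := &splitQueue{}
    sq.offerDescriptorChain([]uint16{1, 2, 3})
}

The trade-off: dropping the inner mutex removes a redundant lock/unlock pair per offer, but any future caller that bypasses offerMutex would race on the ring.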

View File

@@ -4,7 +4,6 @@ import (
     "errors"
     "fmt"
     "math"
-    "os"
     "sync"
     "unsafe"
@@ -54,6 +53,7 @@ type DescriptorTable struct {
     bufferBase uintptr
     bufferSize int
+    itemSize   int

     mu sync.Mutex
 }
@@ -63,7 +63,7 @@ type DescriptorTable struct {
 // descriptor table (see [descriptorTableSize]) for the given queue size.
 //
 // Before this descriptor table can be used, [initialize] must be called.
-func newDescriptorTable(queueSize int, mem []byte) *DescriptorTable {
+func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
     dtSize := descriptorTableSize(queueSize)
     if len(mem) != dtSize {
         panic(fmt.Sprintf("memory size (%v) does not match required size "+
@@ -75,6 +75,7 @@ func newDescriptorTable(queueSize int, mem []byte) *DescriptorTable {
         // We have no free descriptors until they were initialized.
         freeHeadIndex: noFreeHead,
         freeNum:       0,
+        itemSize:      itemSize, //todo configurable? needs to be page-aligned
     }
 }
@@ -116,10 +117,8 @@ func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
 func (dt *DescriptorTable) initializeDescriptors() error {
     numDescriptors := len(dt.descriptors)

-    itemSize := os.Getpagesize() //todo configurable? needs to be page-aligned
-
     // Allocate ONE large region for all buffers
-    totalSize := itemSize * numDescriptors
+    totalSize := dt.itemSize * numDescriptors
     basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
         unix.PROT_READ|unix.PROT_WRITE,
         unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
@@ -136,7 +135,7 @@ func (dt *DescriptorTable) initializeDescriptors() error {
     for i := range dt.descriptors {
         dt.descriptors[i] = Descriptor{
-            address: dt.bufferBase + uintptr(i*itemSize),
+            address: dt.bufferBase + uintptr(i*dt.itemSize),
             length:  0,
             // All descriptors should form a free chain that loops around.
             flags: descriptorFlagHasNext,
@@ -202,8 +201,6 @@ func (dt *DescriptorTable) releaseBuffers() error {
 // caller should try again after some descriptor chains were used by the device
 // and returned back into the free chain.
 func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
-    pageSize := os.Getpagesize()
-
     // Calculate the number of descriptors needed to build the chain.
     numDesc := uint16(len(outBuffers) + numInBuffers)
@@ -217,7 +214,7 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
     // Do we still have enough free descriptors?
     if numDesc > dt.freeNum {
-        return 0, fmt.Errorf("%w: %d free but needed %d", ErrNotEnoughFreeDescriptors, dt.freeNum, numDesc)
+        return 0, ErrNotEnoughFreeDescriptors
     }

     // Above validation ensured that there is at least one free descriptor, so
@@ -238,16 +235,16 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
         desc := &dt.descriptors[next]
         checkUnusedDescriptorLength(next, desc)

-        if len(buffer) > pageSize {
+        if len(buffer) > dt.itemSize {
             // The caller should already prevent that from happening.
-            panic(fmt.Sprintf("out buffer %d has size %d which exceeds page size %d", i, len(buffer), pageSize))
+            panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
         }

         // Copy the buffer to the memory referenced by the descriptor.
         // The descriptor address points to memory not managed by Go, so this
         // conversion is safe. See https://github.com/golang/go/issues/58625
         //goland:noinspection GoVetUnsafePointer
-        copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), pageSize), buffer)
+        copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
         desc.length = uint32(len(buffer))

         // Clear the flags in case there were any others set.
@@ -261,7 +258,7 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
         checkUnusedDescriptorLength(next, desc)

         // Give the device the maximum available number of bytes to write into.
-        desc.length = uint32(pageSize)
+        desc.length = uint32(dt.itemSize)
         // Mark the descriptor as device-writable.
         desc.flags = descriptorFlagHasNext | descriptorFlagWritable

View File

@@ -48,6 +48,7 @@ type SplitQueue struct {
     // [SplitQueue.OfferDescriptorChain].
     offerMutex sync.Mutex

     pageSize int
+    itemSize int
 }

 // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
@@ -61,6 +62,7 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
     sq := SplitQueue{
         size:     queueSize,
         pageSize: os.Getpagesize(),
+        itemSize: os.Getpagesize(), //todo config
     }

     // Clean up a partially initialized queue when something fails.
@@ -109,7 +111,7 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
         return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
     }
-    sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd])
+    sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
     sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
     sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
@@ -241,6 +243,12 @@ func (sq *SplitQueue) consumeUsedRing(ctx context.Context) error {
     return nil
 }

+// blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up.
+// It is kept as its own function so that it shows up as a distinct frame in pprof.
+func (sq *SplitQueue) blockForMoreDescriptors() {
+    <-sq.moreFreeDescriptors
+}
+
 // OfferDescriptorChain offers a descriptor chain to the device which contains a
 // number of device-readable buffers (out buffers) and device-writable buffers
 // (in buffers).
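
Factoring a bare channel receive into a named method is a profiling aid: time spent blocked is attributed to blockForMoreDescriptors rather than folded into its caller's samples. A self-contained sketch of the idea, with hypothetical names:

package main

import "time"

// blockForWork exists only to give the blocking receive its own stack
// frame, so profilers attribute the wait to a named function.
func blockForWork(ch <-chan struct{}) {
    <-ch
}

func main() {
    ch := make(chan struct{})
    go func() {
        time.Sleep(10 * time.Millisecond)
        close(ch)
    }()
    blockForWork(ch) // appears as its own frame in a pprof capture
}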
@@ -292,12 +300,19 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
         if err == nil {
             break
         }
-        if waitFree && errors.Is(err, ErrNotEnoughFreeDescriptors) {
-            // Wait for more free descriptors to be put back into the queue.
-            // If the number of free descriptors is still not sufficient, we'll
-            // land here again.
-            <-sq.moreFreeDescriptors
-            continue
+        // Don't use errors.Is here; it's slow on this hot path.
+        //goland:noinspection GoDirectComparisonOfErrors
+        if err == ErrNotEnoughFreeDescriptors {
+            if waitFree {
+                // Wait for more free descriptors to be put back into the queue.
+                // If the number of free descriptors is still not sufficient, we'll
+                // land here again.
+                sq.blockForMoreDescriptors()
+                continue
+            } else {
+                return 0, err
+            }
         }
         return 0, fmt.Errorf("create descriptor chain: %w", err)
     }
@@ -340,6 +355,7 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
 func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
     sq.ensureInitialized()

+    // not called under lock
     if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
         return fmt.Errorf("free: %w", err)
     }
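
The direct comparison only works because createDescriptorChain now returns the bare ErrNotEnoughFreeDescriptors sentinel instead of wrapping it with fmt.Errorf; a wrapped error would make err == ErrNotEnoughFreeDescriptors false. A micro-benchmark sketch of the cost difference the comment alludes to, in a hypothetical _test.go file (run with go test -bench .):

package virtqueue_test

import (
    "errors"
    "testing"
)

var errSentinel = errors.New("not enough free descriptors")

// Direct comparison is a simple interface equality check.
func BenchmarkDirectCompare(b *testing.B) {
    var err error = errSentinel
    for i := 0; i < b.N; i++ {
        if err != errSentinel {
            b.Fatal("unexpected")
        }
    }
}

// errors.Is additionally checks comparability and walks the Unwrap
// chain, so it costs more per call than a bare equality test.
func BenchmarkErrorsIs(b *testing.B) {
    var err error = errSentinel
    for i := 0; i < b.N; i++ {
        if !errors.Is(err, errSentinel) {
            b.Fatal("unexpected")
        }
    }
}

The trade-off is brittleness: if wrapping is ever reintroduced upstream of this check, the fast path silently stops matching, which is exactly the case errors.Is guards against.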

View File

@@ -2,7 +2,6 @@ package virtqueue

 import (
     "fmt"
-    "sync"
     "unsafe"
 )
@@ -52,7 +51,7 @@ type UsedRing struct {
     // processed.
     lastIndex uint16

-    mu sync.Mutex
+    //mu sync.Mutex
 }

 // newUsedRing creates a used ring that uses the given underlying memory. The
@@ -87,9 +86,10 @@ func (r *UsedRing) Address() uintptr {
 // take returns all new [UsedElement]s that the device put into the ring and
 // that weren't already returned by a previous call to this method.
+// take previously held a lock; it has been removed.
 func (r *UsedRing) take() []UsedElement {
-    r.mu.Lock()
-    defer r.mu.Unlock()
+    //r.mu.Lock()
+    //defer r.mu.Unlock()
     ringIndex := *r.ringIndex
     if ringIndex == r.lastIndex {
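
Removing the UsedRing mutex is safe only if take has exactly one caller at a time; in this diff that caller appears to be the consumeUsedRing loop, so serialization moves from a lock to a single-consumer design. A simplified, self-contained sketch of that pattern (stand-in types, not this repo's API):

package main

import (
    "context"
    "fmt"
    "time"
)

// usedRing is a stand-in for the real UsedRing: lastIndex is mutated
// without a lock, which is only safe with a single consumer.
type usedRing struct {
    elements  []int
    lastIndex int
}

// take must only be called from the one consumer goroutine.
func (r *usedRing) take() []int {
    out := r.elements[r.lastIndex:]
    r.lastIndex = len(r.elements)
    return out
}

func consume(ctx context.Context, r *usedRing, wakeup <-chan struct{}) {
    for {
        select {
        case <-ctx.Done():
            return
        case <-wakeup: // device signal; we are the only reader of the ring
            for _, e := range r.take() {
                fmt.Println("used element:", e)
            }
        }
    }
}

func main() {
    r := &usedRing{elements: []int{1, 2, 3}}
    wakeup := make(chan struct{}, 1)
    wakeup <- struct{}{}
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()
    consume(ctx, r, wakeup)
}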