package virtqueue

import (
	"errors"
	"fmt"
	"math"
	"sync"
	"unsafe"

	"golang.org/x/sys/unix"
)

var (
	// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
	// no buffers, which is not allowed.
	ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")

	// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
	// exhausted, meaning that the queue is full.
	ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")

	// ErrInvalidDescriptorChain is returned when a descriptor chain is not
	// valid for a given operation.
	ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
)

// noFreeHead is used to mark when all descriptors are in use and we have no
// free chain. This value can never occur as a real index, because it exceeds
// the maximum possible queue size.
const noFreeHead = uint16(math.MaxUint16)

// descriptorTableSize returns the number of bytes needed to store a
// [DescriptorTable] with the given queue size in memory.
func descriptorTableSize(queueSize int) int {
	return descriptorSize * queueSize
}

// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
// in memory, as required by the virtio spec.
const descriptorTableAlignment = 16

// DescriptorTable is a table that holds [Descriptor]s, addressed via their
// index in the slice.
type DescriptorTable struct {
	descriptors []Descriptor

	// freeHeadIndex is the index of the head of the descriptor chain which
	// contains all currently unused descriptors. When all descriptors are in
	// use, this has the special value of noFreeHead.
	freeHeadIndex uint16
	// freeNum tracks the number of descriptors which are currently not in use.
	freeNum uint16

	// bufferBase and bufferSize describe the single memory mapping that backs
	// the buffers of all descriptors. Each descriptor owns a slot of itemSize
	// bytes within that mapping.
	bufferBase uintptr
	bufferSize int
	itemSize   int

	mu sync.Mutex
}

// newDescriptorTable creates a descriptor table that uses the given underlying
// memory. The length of the memory slice must match the size needed for the
// descriptor table (see [descriptorTableSize]) for the given queue size.
//
// Before this descriptor table can be used, [initializeDescriptors] must be
// called.
func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
	dtSize := descriptorTableSize(queueSize)
	if len(mem) != dtSize {
		panic(fmt.Sprintf("memory size (%v) does not match required size "+
			"for descriptor table: %v", len(mem), dtSize))
	}

	return &DescriptorTable{
		descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
		// We have no free descriptors until initializeDescriptors has run.
		freeHeadIndex: noFreeHead,
		freeNum:       0,
		itemSize:      itemSize, // TODO: make this configurable? It needs to be page-aligned.
	}
}
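
// The function below is an editorial usage sketch and not part of the original
// file. The queue size, the item size, and the use of plain Go-allocated memory
// for the table are assumptions made for illustration only; real callers would
// place the table in device-shared memory aligned to descriptorTableAlignment
// and derive the sizes from the negotiated virtqueue configuration.
func exampleNewDescriptorTable() (*DescriptorTable, error) {
	const queueSize = 256
	const itemSize = 4096 // assumed page size; itemSize needs to be page-aligned

	// The table itself needs queueSize*descriptorSize bytes of backing memory.
	mem := make([]byte, descriptorTableSize(queueSize))

	dt := newDescriptorTable(queueSize, mem, itemSize)

	// Allocate the buffer memory and build the circular free chain.
	if err := dt.initializeDescriptors(); err != nil {
		return nil, err
	}
	return dt, nil
}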

// Address returns the pointer to the beginning of the descriptor table in
// memory. Do not modify that memory directly, as this would interfere with
// this implementation.
func (dt *DescriptorTable) Address() uintptr {
	if dt.descriptors == nil {
		panic("descriptor table is not initialized")
	}
	// Note that this is the address of the descriptor table itself, not of the
	// buffer memory tracked by bufferBase.
	return uintptr(unsafe.Pointer(&dt.descriptors[0]))
}

// Size returns the total size in bytes of the buffer memory that backs the
// descriptors in this table.
func (dt *DescriptorTable) Size() uintptr {
	if dt.descriptors == nil {
		panic("descriptor table is not initialized")
	}
	return uintptr(dt.bufferSize)
}

// BufferAddresses returns a map of pointer->size for all allocations used by
// the table.
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
	if dt.descriptors == nil {
		panic("descriptor table is not initialized")
	}

	return map[uintptr]int{dt.bufferBase: dt.bufferSize}
}

// initializeDescriptors allocates a buffer of itemSize bytes for each
// descriptor in the table, all carved out of one contiguous memory mapping.
// While this may be a bit wasteful, it makes dealing with descriptors way
// easier. Without this preallocation, we would have to allocate and free
// memory on demand, increasing complexity.
//
// All descriptors will be marked as free and will form a free chain. The
// addresses of all descriptors will be populated while their length remains
// zero.
func (dt *DescriptorTable) initializeDescriptors() error {
	numDescriptors := len(dt.descriptors)

	// Allocate one large region that provides the buffers for all descriptors.
	totalSize := dt.itemSize * numDescriptors
	basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
		unix.PROT_READ|unix.PROT_WRITE,
		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
	if err != nil {
		return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
	}

	dt.mu.Lock()
	defer dt.mu.Unlock()

	// Store the base and size so the mapping can be released later.
	dt.bufferBase = uintptr(basePtr)
	dt.bufferSize = totalSize

	for i := range dt.descriptors {
		dt.descriptors[i] = Descriptor{
			address: dt.bufferBase + uintptr(i*dt.itemSize),
			length:  0,
			// All descriptors should form a free chain that loops around.
			flags: descriptorFlagHasNext,
			next:  uint16((i + 1) % len(dt.descriptors)),
		}
	}

	// All descriptors are free to use now.
	dt.freeHeadIndex = 0
	dt.freeNum = uint16(len(dt.descriptors))

	return nil
}
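
// debugFreeChainLength is an editorial sketch, not part of the original file.
// It walks the circular free chain that initializeDescriptors sets up and
// returns its length, which should always match freeNum. It exists purely to
// illustrate the free-chain invariant.
func (dt *DescriptorTable) debugFreeChainLength() uint16 {
	dt.mu.Lock()
	defer dt.mu.Unlock()

	if dt.freeHeadIndex == noFreeHead {
		// All descriptors are in use, so there is no free chain at all.
		return 0
	}

	// Follow the next pointers until the walk arrives back at the head.
	length := uint16(1)
	for next := dt.descriptors[dt.freeHeadIndex].next; next != dt.freeHeadIndex; next = dt.descriptors[next].next {
		length++
	}
	return length
}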

// releaseBuffers releases the buffer memory backing this descriptor table.
// The descriptor table should no longer be used after calling this.
func (dt *DescriptorTable) releaseBuffers() error {
	dt.mu.Lock()
	defer dt.mu.Unlock()

	for i := range dt.descriptors {
		dt.descriptors[i].address = 0
	}

	// As a safety measure, make sure no descriptors can be used anymore.
	dt.freeHeadIndex = noFreeHead
	dt.freeNum = 0

	if dt.bufferBase != 0 {
		// The pointer points to memory not managed by Go, so this conversion
		// is safe. See https://github.com/golang/go/issues/58625
		//goland:noinspection GoVetUnsafePointer
		base := unsafe.Pointer(dt.bufferBase)
		dt.bufferBase = 0
		if err := unix.MunmapPtr(base, uintptr(dt.bufferSize)); err != nil {
			return fmt.Errorf("release buffer memory: %w", err)
		}
	}

	return nil
}

// createDescriptorChain creates a new descriptor chain within the descriptor
// table which contains a number of device-readable buffers (out buffers) and
// device-writable buffers (in buffers).
//
// All buffers in the outBuffers slice will be concatenated by chaining
// descriptors, one for each buffer in the slice. The size of each single
// buffer must not exceed the item size this table was created with (see
// [newDescriptorTable]).
// When numInBuffers is greater than zero, the given number of device-writable
// descriptors will be appended to the end of the chain, each referencing a
// whole buffer of itemSize bytes.
//
// The index of the head of the new descriptor chain will be returned. Callers
// should make sure to free the descriptor chain using [freeDescriptorChain]
// after it was used by the device.
//
// When there are not enough free descriptors to hold the given number of
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
// caller should try again after some descriptor chains were used by the device
// and returned back into the free chain.
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
	// Calculate the number of descriptors needed to build the chain.
	numDesc := uint16(len(outBuffers) + numInBuffers)

	// Descriptor chains must always contain at least one descriptor.
	if numDesc < 1 {
		return 0, ErrDescriptorChainEmpty
	}

	dt.mu.Lock()
	defer dt.mu.Unlock()

	// Do we still have enough free descriptors?
	if numDesc > dt.freeNum {
		return 0, ErrNotEnoughFreeDescriptors
	}

	// The above validation ensured that there is at least one free descriptor,
	// so the free descriptor chain head should be valid.
	if dt.freeHeadIndex == noFreeHead {
		panic("free descriptor chain head is unset but there should be free descriptors")
	}

	// To avoid having to iterate over the whole table to find the descriptor
	// pointing to the head just to replace the free head, we instead always
	// create descriptor chains from the descriptors coming after the head.
	// This way we only have to touch the head as a last resort, when all other
	// descriptors are already used.
	head := dt.descriptors[dt.freeHeadIndex].next
	next := head
	tail := head
	for i, buffer := range outBuffers {
		desc := &dt.descriptors[next]
		checkUnusedDescriptorLength(next, desc)

		if len(buffer) > dt.itemSize {
			// The caller should already prevent that from happening.
			panic(fmt.Sprintf("out buffer %d has size %d which exceeds descriptor buffer size %d",
				i, len(buffer), dt.itemSize))
		}

		// Copy the buffer into the memory referenced by the descriptor.
		// The descriptor address points to memory not managed by Go, so this
		// conversion is safe. See https://github.com/golang/go/issues/58625
		//goland:noinspection GoVetUnsafePointer
		copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
		desc.length = uint32(len(buffer))

		// Clear the flags in case there were any others set.
		desc.flags = descriptorFlagHasNext

		tail = next
		next = desc.next
	}
	for range numInBuffers {
		desc := &dt.descriptors[next]
		checkUnusedDescriptorLength(next, desc)

		// Give the device the maximum available number of bytes to write into.
		desc.length = uint32(dt.itemSize)

		// Mark the descriptor as device-writable.
		desc.flags = descriptorFlagHasNext | descriptorFlagWritable

		tail = next
		next = desc.next
	}

	// The last descriptor should end the chain.
	tailDesc := &dt.descriptors[tail]
	tailDesc.flags &= ^descriptorFlagHasNext
	tailDesc.next = 0 // Not necessary to clear this, it's just for looks.

	dt.freeNum -= numDesc

	if dt.freeNum == 0 {
		// The last descriptor in the chain should be the free chain head
		// itself.
		if tail != dt.freeHeadIndex {
			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
		}

		// When this new chain takes up all remaining descriptors, we no longer
		// have a free chain.
		dt.freeHeadIndex = noFreeHead
	} else {
		// We took some descriptors out of the free chain, so make sure to close
		// the circle again.
		dt.descriptors[dt.freeHeadIndex].next = next
	}

	return head, nil
}

// TODO: Implement a zero-copy variant of createDescriptorChain?
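
// The method below is an editorial usage sketch and not part of the original
// file. It shows the typical lifecycle of a transmit chain built with the API
// above. Publishing the head index to the device (available ring, notification)
// and waiting for it to show up in the used ring are assumed to happen
// elsewhere and are elided here.
func (dt *DescriptorTable) exampleTransmitChain(packet []byte) error {
	// packet must not be larger than the table's item size.
	head, err := dt.createDescriptorChain([][]byte{packet}, 0)
	if err != nil {
		// Typically ErrNotEnoughFreeDescriptors: retry after the device has
		// consumed some chains and they have been freed again.
		return err
	}

	// ... hand head to the device and wait until it has processed the chain ...

	// Return the descriptors to the free chain so they can be reused.
	return dt.freeDescriptorChain(head)
}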

// getDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain that starts with
// the given head index. The descriptor chain must have been created using
// [createDescriptorChain] and must not have been freed yet (meaning that the
// head index must not be contained in the free chain).
//
// Be careful to only access the returned buffer slices when the device has not
// yet or is no longer using them. They must not be accessed after
// [freeDescriptorChain] has been called.
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
	if int(head) >= len(dt.descriptors) {
		return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
	}

	dt.mu.Lock()
	defer dt.mu.Unlock()

	// Iterate over the chain. The iteration is limited to the queue size to
	// avoid ending up in an endless loop when things go very wrong.
	next := head
	for range len(dt.descriptors) {
		if next == dt.freeHeadIndex {
			return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
		}

		desc := &dt.descriptors[next]

		// The descriptor address points to memory not managed by Go, so this
		// conversion is safe. See https://github.com/golang/go/issues/58625
		//goland:noinspection GoVetUnsafePointer
		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)

		if desc.flags&descriptorFlagWritable == 0 {
			outBuffers = append(outBuffers, bs)
		} else {
			inBuffers = append(inBuffers, bs)
		}

		// Is this the tail of the chain?
		if desc.flags&descriptorFlagHasNext == 0 {
			break
		}

		// Detect loops.
		if desc.next == head {
			return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
		}

		next = desc.next
	}

	return
}
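
// The helper below is an editorial sketch, not part of the original file. It
// uses getDescriptorChain to look at the device-writable buffers of a chain in
// place, without the copy that getDescriptorChainContents performs. The caller
// must make sure the device is done with the chain and must stop using the
// slices before the chain is freed.
func (dt *DescriptorTable) exampleInspectInBuffers(head uint16) (int, error) {
	_, inBuffers, err := dt.getDescriptorChain(head)
	if err != nil {
		return 0, err
	}

	// Total the number of bytes the device-writable buffers can hold.
	total := 0
	for _, buf := range inBuffers {
		total += len(buf)
	}
	return total, nil
}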

// getDescriptorChainContents copies the contents of the device-writable
// buffers of the descriptor chain that starts with the given head index into
// out and returns the number of bytes copied. It is intended for chains on a
// receive queue, which consist only of device-writable buffers. The out slice
// must be large enough to hold the complete chain contents.
// The descriptor chain must have been created using [createDescriptorChain]
// and must not have been freed yet.
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (int, error) {
	if int(head) >= len(dt.descriptors) {
		return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
	}

	dt.mu.Lock()
	defer dt.mu.Unlock()

	// Iterate over the chain to determine its total length. The iteration is
	// limited to the queue size to avoid ending up in an endless loop when
	// things go very wrong.
	length := 0
	next := head
	for range len(dt.descriptors) {
		if next == dt.freeHeadIndex {
			return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
		}

		desc := &dt.descriptors[next]

		if desc.flags&descriptorFlagWritable == 0 {
			return 0, fmt.Errorf("receive queue contains device-readable buffer")
		}
		length += int(desc.length)

		// Is this the tail of the chain?
		if desc.flags&descriptorFlagHasNext == 0 {
			break
		}

		// Detect loops.
		if desc.next == head {
			return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
		}

		next = desc.next
	}

	// Limit out to the exact chain length.
	out = out[:length]

	// Now copy the chain contents, starting from the head again.
	copied := 0
	next = head
	for range len(dt.descriptors) {
		desc := &dt.descriptors[next]

		// The descriptor address points to memory not managed by Go, so this
		// conversion is safe. See https://github.com/golang/go/issues/58625
		//goland:noinspection GoVetUnsafePointer
		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
		copied += copy(out[copied:], bs)

		// Is this the tail of the chain?
		if desc.flags&descriptorFlagHasNext == 0 {
			break
		}

		// The first pass already checked for loops, no need to do it again.
		next = desc.next
	}
	if copied != length {
		panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
	}

	return length, nil
}
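
// exampleReceiveChain is an editorial sketch, not part of the original file.
// It shows how a receive path might post a chain of device-writable buffers,
// copy the data out once the device has filled them, and recycle the chain.
// Waiting for the device is elided; the out buffer sizing relies on every in
// buffer being itemSize bytes long, as set up by createDescriptorChain.
func (dt *DescriptorTable) exampleReceiveChain(numInBuffers int) ([]byte, error) {
	head, err := dt.createDescriptorChain(nil, numInBuffers)
	if err != nil {
		return nil, err
	}

	// ... publish head to the device and wait until the chain was used ...

	out := make([]byte, numInBuffers*dt.itemSize)
	n, err := dt.getDescriptorChainContents(head, out)
	if err != nil {
		return nil, err
	}

	// Put the descriptors back into the free chain before handing the copied
	// data to the caller.
	if err := dt.freeDescriptorChain(head); err != nil {
		return nil, err
	}
	return out[:n], nil
}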

// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
// of [createDescriptorChain].
// The descriptor chain must have been created using [createDescriptorChain] and
// must not have been freed yet (meaning that the head index must not be
// contained in the free chain).
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
	if int(head) >= len(dt.descriptors) {
		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
	}

	dt.mu.Lock()
	defer dt.mu.Unlock()

	// Iterate over the chain. The iteration is limited to the queue size to
	// avoid ending up in an endless loop when things go very wrong.
	next := head
	var tailDesc *Descriptor
	var chainLen uint16
	for range len(dt.descriptors) {
		if next == dt.freeHeadIndex {
			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
		}

		desc := &dt.descriptors[next]
		chainLen++

		// Set the length of all unused descriptors back to zero.
		desc.length = 0

		// Unset all flags except the next flag.
		desc.flags &= descriptorFlagHasNext

		// Is this the tail of the chain?
		if desc.flags&descriptorFlagHasNext == 0 {
			tailDesc = desc
			break
		}

		// Detect loops.
		if desc.next == head {
			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
		}

		next = desc.next
	}
	if tailDesc == nil {
		// A descriptor chain longer than the queue size but without loops
		// should be impossible.
		panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
	}

	// The tail descriptor does not have the next flag set, but when it comes
	// back into the free chain, it should have.
	tailDesc.flags = descriptorFlagHasNext

	if dt.freeHeadIndex == noFreeHead {
		// The whole free chain was used up, so we turn this returned descriptor
		// chain into the new free chain by completing the circle and using its
		// head.
		tailDesc.next = head
		dt.freeHeadIndex = head
	} else {
		// Attach the returned chain at the beginning of the free chain but
		// right after the free chain head.
		freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
		tailDesc.next = freeHeadDesc.next
		freeHeadDesc.next = head
	}

	dt.freeNum += chainLen

	return nil
}
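
// exampleCreateFreeRoundTrip is an editorial sketch, not part of the original
// file. It illustrates the bookkeeping guarantee of createDescriptorChain and
// freeDescriptorChain: after a chain is created and freed again, the free
// descriptor count is back where it started. Reading freeNum without the lock
// is only safe because this sketch assumes no concurrent users.
func (dt *DescriptorTable) exampleCreateFreeRoundTrip(payload []byte) error {
	before := dt.freeNum

	// One out buffer plus one in buffer takes two descriptors from the free
	// chain; payload must not exceed the table's item size.
	head, err := dt.createDescriptorChain([][]byte{payload}, 1)
	if err != nil {
		return err
	}

	if err := dt.freeDescriptorChain(head); err != nil {
		return err
	}

	if dt.freeNum != before {
		return fmt.Errorf("free descriptor count changed from %d to %d", before, dt.freeNum)
	}
	return nil
}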

// checkUnusedDescriptorLength asserts that the length of an unused descriptor
// is zero, as it should be.
// This is not a requirement of the virtio spec but rather a thing we do to
// notice when our algorithm goes sideways.
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
	if desc.length != 0 {
		panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
	}
}