mirror of https://github.com/slackhq/nebula.git (synced 2025-11-22 16:34:25 +01:00)
pull deps in for optimization, maybe slice back out later
@@ -16,10 +16,10 @@ import (
 	"unsafe"

 	"github.com/gaissmai/bart"
-	"github.com/hetznercloud/virtio-go/virtio"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/overlay/vhostnet"
+	"github.com/slackhq/nebula/overlay/virtio"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/vishvananda/netlink"
@@ -281,39 +281,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	return r
 }

-func (t *tun) Read(p []byte) (int, error) {
-	hdr, out, err := t.vdev.ReceivePacket() //we are TXing
-	if err != nil {
-		return 0, err
-	}
-	if hdr.NumBuffers == 0 {
-
-	}
-	p = p[:len(out)]
-	copy(p, out)
-	return len(out), nil
-}
-
-func (t *tun) Write(b []byte) (int, error) {
-	maximum := len(b) //we are RXing
-
-	hdr := virtio.NetHdr{ //todo
-		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
-		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
-		HdrLen:     0,
-		GSOSize:    0,
-		CsumStart:  0,
-		CsumOffset: 0,
-		NumBuffers: 0,
-	}
-
-	err := t.vdev.TransmitPacket(hdr, b)
-	if err != nil {
-		return 0, err
-	}
-	return maximum, nil
-}
-
 func (t *tun) deviceBytes() (o [16]byte) {
 	for i, c := range t.Device {
 		o[i] = byte(c)
@@ -740,3 +707,36 @@ func (t *tun) Close() error {

 	return nil
 }
+
+func (t *tun) Read(p []byte) (int, error) {
+	hdr, out, err := t.vdev.ReceivePacket() //we are TXing
+	if err != nil {
+		return 0, err
+	}
+	if hdr.NumBuffers > 1 {
+		t.l.WithField("num_buffers", hdr.NumBuffers).Info("wow, lots to TX from tun")
+	}
+	p = p[:len(out)]
+	copy(p, out)
+	return len(out), nil
+}
+
+func (t *tun) Write(b []byte) (int, error) {
+	maximum := len(b) //we are RXing
+
+	hdr := virtio.NetHdr{ //todo
+		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
+		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
+		HdrLen:     0,
+		GSOSize:    0,
+		CsumStart:  0,
+		CsumOffset: 0,
+		NumBuffers: 0,
+	}
+
+	err := t.vdev.TransmitPacket(hdr, b)
+	if err != nil {
+		return 0, err
+	}
+	return maximum, nil
+}
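The `//we are TXing` / `//we are RXing` comments flag a direction inversion worth spelling out: this tun type sits on the device side of the vhost-net split, so packets the host routes into the TUN interface surface on the virtio transmit queue (drained via ReceivePacket), while packets written here are handed back to the host as received traffic (via TransmitPacket). A minimal sketch of how a caller consumes the io.ReadWriter contract; forwardToPeer is a hypothetical helper and this harness is not part of the commit:

	// Sketch: pump packets from the host into the overlay.
	buf := make([]byte, 65536)
	for {
		n, err := t.Read(buf) // one packet off the host's TX queue
		if err != nil {
			return err
		}
		forwardToPeer(buf[:n]) // hypothetical: encrypt and send to the peer
	}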

overlay/vhost/doc.go (new file, 4 lines)
@@ -0,0 +1,4 @@
+// Package vhost implements the basic ioctl requests needed to interact with the
+// kernel-level virtio server that provides accelerated virtio devices for
+// networking and more.
+package vhost
overlay/vhost/ioctl.go (new file, 218 lines)
@@ -0,0 +1,218 @@
+package vhost
+
+import (
+	"fmt"
+	"unsafe"
+
+	"github.com/slackhq/nebula/overlay/virtio"
+	"github.com/slackhq/nebula/overlay/virtqueue"
+	"golang.org/x/sys/unix"
+)
+
+const (
+	// vhostIoctlGetFeatures can be used to retrieve the features supported by
+	// the vhost implementation in the kernel.
+	//
+	// Response payload: [virtio.Feature]
+	// Kernel name: VHOST_GET_FEATURES
+	vhostIoctlGetFeatures = 0x8008af00
+
+	// vhostIoctlSetFeatures can be used to communicate the features supported
+	// by this virtio implementation to the kernel.
+	//
+	// Request payload: [virtio.Feature]
+	// Kernel name: VHOST_SET_FEATURES
+	vhostIoctlSetFeatures = 0x4008af00
+
+	// vhostIoctlSetOwner can be used to set the current process as the
+	// exclusive owner of a control file descriptor.
+	//
+	// Request payload: none
+	// Kernel name: VHOST_SET_OWNER
+	vhostIoctlSetOwner = 0x0000af01
+
+	// vhostIoctlSetMemoryLayout can be used to set up or modify the memory
+	// layout which describes the IOTLB mappings in the kernel.
+	//
+	// Request payload: [MemoryLayout] with custom serialization
+	// Kernel name: VHOST_SET_MEM_TABLE
+	vhostIoctlSetMemoryLayout = 0x4008af03
+
+	// vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
+	//
+	// Request payload: [QueueState]
+	// Kernel name: VHOST_SET_VRING_NUM
+	vhostIoctlSetQueueSize = 0x4008af10
+
+	// vhostIoctlSetQueueAddress can be used to set the addresses of the
+	// different parts of the virtqueue.
+	//
+	// Request payload: [QueueAddresses]
+	// Kernel name: VHOST_SET_VRING_ADDR
+	vhostIoctlSetQueueAddress = 0x4028af11
+
+	// vhostIoctlSetAvailableRingBase can be used to set the index of the next
+	// available ring entry the device will process.
+	//
+	// Request payload: [QueueState]
+	// Kernel name: VHOST_SET_VRING_BASE
+	vhostIoctlSetAvailableRingBase = 0x4008af12
+
+	// vhostIoctlSetQueueKickEventFD can be used to set the event file
+	// descriptor to signal the device when descriptor chains were added to the
+	// available ring.
+	//
+	// Request payload: [QueueFile]
+	// Kernel name: VHOST_SET_VRING_KICK
+	vhostIoctlSetQueueKickEventFD = 0x4008af20
+
+	// vhostIoctlSetQueueCallEventFD can be used to set the event file
+	// descriptor that gets signaled by the device when descriptor chains have
+	// been used by it.
+	//
+	// Request payload: [QueueFile]
+	// Kernel name: VHOST_SET_VRING_CALL
+	vhostIoctlSetQueueCallEventFD = 0x4008af21
+)
+
+// QueueState is an ioctl request payload that can hold a queue index and any
+// 32-bit number.
+//
+// Kernel name: vhost_vring_state
+type QueueState struct {
+	// QueueIndex is the index of the virtqueue.
+	QueueIndex uint32
+	// Num is any 32-bit number, depending on the request.
+	Num uint32
+}
+
+// QueueAddresses is an ioctl request payload that can hold the addresses of the
+// different parts of a virtqueue.
+//
+// Kernel name: vhost_vring_addr
+type QueueAddresses struct {
+	// QueueIndex is the index of the virtqueue.
+	QueueIndex uint32
+	// Flags that are not used in this implementation.
+	Flags uint32
+	// DescriptorTableAddress is the address of the descriptor table in user
+	// space memory. It must be 16-byte aligned.
+	DescriptorTableAddress uintptr
+	// UsedRingAddress is the address of the used ring in user space memory. It
+	// must be 4-byte aligned.
+	UsedRingAddress uintptr
+	// AvailableRingAddress is the address of the available ring in user space
+	// memory. It must be 2-byte aligned.
+	AvailableRingAddress uintptr
+	// LogAddress is used for optional logging support, which is not supported
+	// by this implementation.
+	LogAddress uintptr
+}
+
+// QueueFile is an ioctl request payload that can hold a queue index and a file
+// descriptor.
+//
+// Kernel name: vhost_vring_file
+type QueueFile struct {
+	// QueueIndex is the index of the virtqueue.
+	QueueIndex uint32
+	// FD is the file descriptor of the file. Pass -1 to unbind from a file.
+	FD int32
+}
+
+// IoctlPtr is a copy of the similarly named unexported function from the Go
+// unix package. This is needed to do custom ioctl requests not supported by the
+// standard library.
+func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
+	_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
+	if err != 0 {
+		return fmt.Errorf("ioctl request %d: %w", req, err)
+	}
+	return nil
+}
+
+// GetFeatures requests the supported feature bits from the virtio device
+// associated with the given control file descriptor.
+func GetFeatures(controlFD int) (virtio.Feature, error) {
+	var features virtio.Feature
+	if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
+		return 0, fmt.Errorf("get features: %w", err)
+	}
+	return features, nil
+}
+
+// SetFeatures communicates the feature bits supported by this implementation
+// to the virtio device associated with the given control file descriptor.
+func SetFeatures(controlFD int, features virtio.Feature) error {
+	if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
+		return fmt.Errorf("set features: %w", err)
+	}
+	return nil
+}
+
+// OwnControlFD sets the current process as the exclusive owner for the
+// given control file descriptor. This must be called before interacting with
+// the control file descriptor in any other way.
+func OwnControlFD(controlFD int) error {
+	if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
+		return fmt.Errorf("set control file descriptor owner: %w", err)
+	}
+	return nil
+}
+
+// SetMemoryLayout sets up or modifies the memory layout for the kernel-level
+// virtio device associated with the given control file descriptor.
+func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
+	payload := layout.serializePayload()
+	if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
+		return fmt.Errorf("set memory layout: %w", err)
+	}
+	return nil
+}
+
+// RegisterQueue registers a virtio queue with the kernel-level virtio server.
+// The virtqueue will be linked to the given control file descriptor and will
+// have the given index. The kernel will use this queue until the control file
+// descriptor is closed.
+func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
+		QueueIndex: queueIndex,
+		Num:        uint32(queue.Size()),
+	})); err != nil {
+		return fmt.Errorf("set queue size: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
+		QueueIndex:             queueIndex,
+		Flags:                  0,
+		DescriptorTableAddress: queue.DescriptorTable().Address(),
+		UsedRingAddress:        queue.UsedRing().Address(),
+		AvailableRingAddress:   queue.AvailableRing().Address(),
+		LogAddress:             0,
+	})); err != nil {
+		return fmt.Errorf("set queue addresses: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
+		QueueIndex: queueIndex,
+		Num:        0,
+	})); err != nil {
+		return fmt.Errorf("set available ring base: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
+		QueueIndex: queueIndex,
+		FD:         int32(queue.KickEventFD()),
+	})); err != nil {
+		return fmt.Errorf("set kick event file descriptor: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
+		QueueIndex: queueIndex,
+		FD:         int32(queue.CallEventFD()),
+	})); err != nil {
+		return fmt.Errorf("set call event file descriptor: %w", err)
+	}
+
+	return nil
+}
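A quick way to sanity-check these constants: Linux ioctl request numbers encode the transfer direction in bits 31:30, the payload size in bits 29:16, the magic (0xAF for vhost) in bits 15:8, and the request number in bits 7:0. A small sketch, not part of the commit and assuming that standard encoding, reproduces the values above:

	// dir<<30 | size<<16 | magic<<8 | nr, per the Linux _IO/_IOR/_IOW macros.
	func ioc(dir, size, magic, nr uint) uint {
		return dir<<30 | size<<16 | magic<<8 | nr
	}

	// ioc(2, 8, 0xaf, 0x00) == 0x8008af00  // _IOR: VHOST_GET_FEATURES, 8-byte payload
	// ioc(1, 8, 0xaf, 0x00) == 0x4008af00  // _IOW: VHOST_SET_FEATURES
	// ioc(0, 0, 0xaf, 0x01) == 0x0000af01  // _IO:  VHOST_SET_OWNER, no payload
	// ioc(1, 40, 0xaf, 0x11) == 0x4028af11 // _IOW: VHOST_SET_VRING_ADDR, 40-byte QueueAddresses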
overlay/vhost/ioctl_test.go (new file, 21 lines)
@@ -0,0 +1,21 @@
+package vhost_test
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/hetznercloud/virtio-go/vhost"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestQueueState_Size(t *testing.T) {
+	assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
+}
+
+func TestQueueAddresses_Size(t *testing.T) {
+	assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
+}
+
+func TestQueueFile_Size(t *testing.T) {
+	assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
+}
overlay/vhost/memory.go (new file, 75 lines)
@@ -0,0 +1,75 @@
+package vhost
+
+import (
+	"encoding/binary"
+	"fmt"
+	"os"
+	"unsafe"
+
+	"github.com/slackhq/nebula/overlay/virtqueue"
+)
+
+// MemoryRegion describes a region of userspace memory which is being made
+// accessible to a vhost device.
+//
+// Kernel name: vhost_memory_region
+type MemoryRegion struct {
+	// GuestPhysicalAddress is the physical address of the memory region within
+	// the guest, when virtualization is used. When no virtualization is used,
+	// this should be the same as UserspaceAddress.
+	GuestPhysicalAddress uintptr
+	// Size is the size of the memory region.
+	Size uint64
+	// UserspaceAddress is the virtual address in the userspace of the host
+	// where the memory region can be found.
+	UserspaceAddress uintptr
+	// Padding and room for flags. Currently unused.
+	_ uint64
+}
+
+// MemoryLayout is a list of [MemoryRegion]s.
+type MemoryLayout []MemoryRegion
+
+// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
+// memory pages used by the descriptor tables of the given queues.
+func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
+	pageSize := os.Getpagesize()
+	regions := make([]MemoryRegion, 0)
+	for _, queue := range queues {
+		for _, address := range queue.DescriptorTable().BufferAddresses() {
+			regions = append(regions, MemoryRegion{
+				// There is no virtualization in play here, so the guest address
+				// is the same as in the host's userspace.
+				GuestPhysicalAddress: address,
+				Size:                 uint64(pageSize),
+				UserspaceAddress:     address,
+			})
+		}
+	}
+	return regions
+}
+
+// serializePayload serializes the list of memory regions into a format that is
+// compatible with the vhost_memory kernel struct. The returned byte slice can
+// be used as a payload for the vhostIoctlSetMemoryLayout ioctl.
+func (regions MemoryLayout) serializePayload() []byte {
+	regionCount := len(regions)
+	regionSize := int(unsafe.Sizeof(MemoryRegion{}))
+	payload := make([]byte, 8+regionCount*regionSize)
+
+	// The first 32 bits contain the number of memory regions. The following 32
+	// bits are padding.
+	binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
+
+	if regionCount > 0 {
+		// The underlying byte array of the slice should already have the
+		// correct format, so just copy that.
+		copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(&regions[0])), regionCount*regionSize))
+		if copied != regionCount*regionSize {
+			panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
+				copied, regionCount*regionSize))
+		}
+	}
+
+	return payload
+}
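Together with ioctl.go this covers the whole bring-up handshake for a vhost device. A condensed sketch of the expected call order, with error handling trimmed; the control file descriptor and queue construction are assumed to come from elsewhere (e.g. the vhostnet package), so treat this as illustration rather than the package's actual wiring:

	// Sketch: hand a device two queues (RX index 0, TX index 1).
	func setupDevice(controlFD int, rx, tx *virtqueue.SplitQueue) error {
		if err := OwnControlFD(controlFD); err != nil { // must be the first request
			return err
		}
		// Describe which userspace pages hold the queues' buffers ...
		layout := NewMemoryLayoutForQueues([]*virtqueue.SplitQueue{rx, tx})
		if err := SetMemoryLayout(controlFD, layout); err != nil {
			return err
		}
		// ... then point the kernel at each queue's rings and event FDs.
		if err := RegisterQueue(controlFD, 0, rx); err != nil {
			return err
		}
		return RegisterQueue(controlFD, 1, tx)
	}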
overlay/vhost/memory_internal_test.go (new file, 42 lines)
@@ -0,0 +1,42 @@
+package vhost
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMemoryRegion_Size(t *testing.T) {
+	assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
+}
+
+func TestMemoryLayout_SerializePayload(t *testing.T) {
+	layout := MemoryLayout([]MemoryRegion{
+		{
+			GuestPhysicalAddress: 42,
+			Size:                 100,
+			UserspaceAddress:     142,
+		}, {
+			GuestPhysicalAddress: 99,
+			Size:                 100,
+			UserspaceAddress:     99,
+		},
+	})
+	payload := layout.serializePayload()
+
+	assert.Equal(t, []byte{
+		0x02, 0x00, 0x00, 0x00, // nregions
+		0x00, 0x00, 0x00, 0x00, // padding
+		// region 0
+		0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
+		0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
+		0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
+		// region 1
+		0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
+		0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
+		0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
+	}, payload)
+}
@@ -6,9 +6,9 @@ import (
 	"os"
 	"runtime"

-	"github.com/hetznercloud/virtio-go/vhost"
-	"github.com/hetznercloud/virtio-go/virtio"
-	"github.com/hetznercloud/virtio-go/virtqueue"
+	"github.com/slackhq/nebula/overlay/vhost"
+	"github.com/slackhq/nebula/overlay/virtio"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"golang.org/x/sys/unix"
 )

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"unsafe"

-	"github.com/hetznercloud/virtio-go/vhost"
+	"github.com/slackhq/nebula/overlay/vhost"
 )

 const (
@@ -3,8 +3,7 @@ package vhostnet
 import (
 	"errors"

-	"github.com/hetznercloud/virtio-go/tuntap"
-	"github.com/hetznercloud/virtio-go/virtqueue"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 )

 type optionValues struct {
@@ -60,11 +59,11 @@ func WithBackendFD(backendFD int) Option {
 	return func(o *optionValues) { o.backendFD = backendFD }
 }

-// WithBackendDevice returns an [Option] that sets the given TAP device as the
-// backend that will be used for the queues of the device. The device will
-// write and read packets to/from that backend. The TAP device should have been
-// created with the [tuntap.WithVirtioNetHdr] option enabled.
-// Either this or [WithBackendFD] is required.
-func WithBackendDevice(dev *tuntap.Device) Option {
-	return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
-}
+//// WithBackendDevice returns an [Option] that sets the given TAP device as the
+//// backend that will be used for the queues of the device. The device will
+//// write and read packets to/from that backend. The TAP device should have been
+//// created with the [tuntap.WithVirtioNetHdr] option enabled.
+//// Either this or [WithBackendFD] is required.
+//func WithBackendDevice(dev *tuntap.Device) Option {
+//	return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
+//}

overlay/virtio/doc.go (new file, 3 lines)
@@ -0,0 +1,3 @@
+// Package virtio contains some generic types and concepts related to the virtio
+// protocol.
+package virtio
overlay/virtio/features.go (new file, 136 lines)
@@ -0,0 +1,136 @@
+package virtio
+
+// Feature contains feature bits that describe a virtio device or driver.
+type Feature uint64
+
+// Device-independent feature bits.
+//
+// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-6600006
+const (
+	// FeatureIndirectDescriptors indicates that the driver can use descriptors
+	// with an additional layer of indirection.
+	FeatureIndirectDescriptors Feature = 1 << 28
+
+	// FeatureVersion1 indicates compliance with version 1.0 of the virtio
+	// specification.
+	FeatureVersion1 Feature = 1 << 32
+)
+
+// Feature bits for networking devices.
+//
+// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-2200003
+const (
+	// FeatureNetDeviceCsum indicates that the device can handle packets with
+	// partial checksum (checksum offload).
+	FeatureNetDeviceCsum Feature = 1 << 0
+
+	// FeatureNetDriverCsum indicates that the driver can handle packets with
+	// partial checksum.
+	FeatureNetDriverCsum Feature = 1 << 1
+
+	// FeatureNetCtrlDriverOffloads indicates support for dynamic offload state
+	// reconfiguration.
+	FeatureNetCtrlDriverOffloads Feature = 1 << 2
+
+	// FeatureNetMTU indicates that the device reports a maximum MTU value.
+	FeatureNetMTU Feature = 1 << 3
+
+	// FeatureNetMAC indicates that the device provides a MAC address.
+	FeatureNetMAC Feature = 1 << 5
+
+	// FeatureNetDriverTSO4 indicates that the driver supports the TCP
+	// segmentation offload for received IPv4 packets.
+	FeatureNetDriverTSO4 Feature = 1 << 7
+
+	// FeatureNetDriverTSO6 indicates that the driver supports the TCP
+	// segmentation offload for received IPv6 packets.
+	FeatureNetDriverTSO6 Feature = 1 << 8
+
+	// FeatureNetDriverECN indicates that the driver supports the TCP
+	// segmentation offload with ECN for received packets.
+	FeatureNetDriverECN Feature = 1 << 9
+
+	// FeatureNetDriverUFO indicates that the driver supports the UDP
+	// fragmentation offload for received packets.
+	FeatureNetDriverUFO Feature = 1 << 10
+
+	// FeatureNetDeviceTSO4 indicates that the device supports the TCP
+	// segmentation offload for received IPv4 packets.
+	FeatureNetDeviceTSO4 Feature = 1 << 11
+
+	// FeatureNetDeviceTSO6 indicates that the device supports the TCP
+	// segmentation offload for received IPv6 packets.
+	FeatureNetDeviceTSO6 Feature = 1 << 12
+
+	// FeatureNetDeviceECN indicates that the device supports the TCP
+	// segmentation offload with ECN for received packets.
+	FeatureNetDeviceECN Feature = 1 << 13
+
+	// FeatureNetDeviceUFO indicates that the device supports the UDP
+	// fragmentation offload for received packets.
+	FeatureNetDeviceUFO Feature = 1 << 14
+
+	// FeatureNetMergeRXBuffers indicates that the driver can handle merged
+	// receive buffers.
+	// When this feature is negotiated, devices may merge multiple descriptor
+	// chains together to transport large received packets. [NetHdr.NumBuffers]
+	// will then contain the number of merged descriptor chains.
+	FeatureNetMergeRXBuffers Feature = 1 << 15
+
+	// FeatureNetStatus indicates that the device configuration status field is
+	// available.
+	FeatureNetStatus Feature = 1 << 16
+
+	// FeatureNetCtrlVQ indicates that a control channel virtqueue is
+	// available.
+	FeatureNetCtrlVQ Feature = 1 << 17
+
+	// FeatureNetCtrlRX indicates support for RX mode control (e.g. promiscuous
+	// or all-multicast) for packet receive filtering.
+	FeatureNetCtrlRX Feature = 1 << 18
+
+	// FeatureNetCtrlVLAN indicates support for VLAN filtering through the
+	// control channel.
+	FeatureNetCtrlVLAN Feature = 1 << 19
+
+	// FeatureNetDriverAnnounce indicates that the driver can send gratuitous
+	// packets.
+	FeatureNetDriverAnnounce Feature = 1 << 21
+
+	// FeatureNetMQ indicates that the device supports multiqueue with
+	// automatic receive steering.
+	FeatureNetMQ Feature = 1 << 22
+
+	// FeatureNetCtrlMACAddr indicates that the MAC address can be set through
+	// the control channel.
+	FeatureNetCtrlMACAddr Feature = 1 << 23
+
+	// FeatureNetDeviceUSO indicates that the device supports the UDP
+	// segmentation offload for received packets.
+	FeatureNetDeviceUSO Feature = 1 << 56
+
+	// FeatureNetHashReport indicates that the device can report a per-packet
+	// hash value and type.
+	FeatureNetHashReport Feature = 1 << 57
+
+	// FeatureNetDriverHdrLen indicates that the driver can provide the exact
+	// header length value (see [NetHdr.HdrLen]).
+	// Devices may benefit from knowing the exact header length.
+	FeatureNetDriverHdrLen Feature = 1 << 59
+
+	// FeatureNetRSS indicates that the device supports RSS (receive-side
+	// scaling) with configurable hash parameters.
+	FeatureNetRSS Feature = 1 << 60
+
+	// FeatureNetRSCExt indicates that the device can process duplicated ACKs
+	// and report the number of coalesced segments and duplicated ACKs.
+	FeatureNetRSCExt Feature = 1 << 61
+
+	// FeatureNetStandby indicates that the device may act as a standby for a
+	// primary device with the same MAC address.
+	FeatureNetStandby Feature = 1 << 62
+
+	// FeatureNetSpeedDuplex indicates that the device can report link speed
+	// and duplex mode.
+	FeatureNetSpeedDuplex Feature = 1 << 63
+)
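These bits are negotiated, not merely declared: the driver reads what the device offers, intersects that with what it supports, and writes the result back. A short fragment using the vhost helpers from this commit (simplified; controlFD is assumed to be an open control file descriptor, and real negotiation would check more bits):

	// Sketch: negotiate a minimal feature set over a vhost control FD.
	offered, err := vhost.GetFeatures(controlFD)
	if err != nil {
		return err
	}
	wanted := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
	negotiated := offered & wanted
	if negotiated&virtio.FeatureVersion1 == 0 {
		return errors.New("device does not support virtio 1.0")
	}
	if err := vhost.SetFeatures(controlFD, negotiated); err != nil {
		return err
	}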
overlay/virtio/net_hdr.go (new file, 77 lines)
@@ -0,0 +1,77 @@
+package virtio
+
+import (
+	"errors"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+// Workaround to make Go doc links work.
+var _ unix.Errno
+
+// NetHdrSize is the number of bytes needed to store a [NetHdr] in memory.
+const NetHdrSize = 12
+
+// ErrNetHdrBufferTooSmall is returned when a buffer is too small to fit a
+// virtio_net_hdr.
+var ErrNetHdrBufferTooSmall = errors.New("the buffer is too small to fit a virtio_net_hdr")
+
+// NetHdr defines the virtio_net_hdr as described by the virtio specification.
+type NetHdr struct {
+	// Flags that describe the packet.
+	// Possible values are:
+	//   - [unix.VIRTIO_NET_HDR_F_NEEDS_CSUM]
+	//   - [unix.VIRTIO_NET_HDR_F_DATA_VALID]
+	//   - [unix.VIRTIO_NET_HDR_F_RSC_INFO]
+	Flags uint8
+	// GSOType contains the type of segmentation offload that should be used
+	// for the packet.
+	// Possible values are:
+	//   - [unix.VIRTIO_NET_HDR_GSO_NONE]
+	//   - [unix.VIRTIO_NET_HDR_GSO_TCPV4]
+	//   - [unix.VIRTIO_NET_HDR_GSO_UDP]
+	//   - [unix.VIRTIO_NET_HDR_GSO_TCPV6]
+	//   - [unix.VIRTIO_NET_HDR_GSO_UDP_L4]
+	//   - [unix.VIRTIO_NET_HDR_GSO_ECN]
+	GSOType uint8
+	// HdrLen contains the length of the headers that need to be replicated by
+	// segmentation offloads. It's the number of bytes from the beginning of
+	// the packet to the beginning of the transport payload.
+	// Only used when [FeatureNetDriverHdrLen] is negotiated.
+	HdrLen uint16
+	// GSOSize contains the maximum size of each segmented packet beyond the
+	// header (payload size). In case of TCP, this is the MSS.
+	GSOSize uint16
+	// CsumStart contains the offset within the packet from which on the
+	// checksum should be computed.
+	CsumStart uint16
+	// CsumOffset specifies how many bytes after [NetHdr.CsumStart] the
+	// computed 16-bit checksum should be inserted.
+	CsumOffset uint16
+	// NumBuffers contains the number of merged descriptor chains when
+	// [FeatureNetMergeRXBuffers] is negotiated.
+	// This field is only used for packets received by the driver and should
+	// be zero for transmitted packets.
+	NumBuffers uint16
+}
+
+// Decode decodes the [NetHdr] from the given byte slice. The slice must
+// contain at least [NetHdrSize] bytes.
+func (v *NetHdr) Decode(data []byte) error {
+	if len(data) < NetHdrSize {
+		return ErrNetHdrBufferTooSmall
+	}
+	copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize), data[:NetHdrSize])
+	return nil
+}
+
+// Encode encodes the [NetHdr] into the given byte slice. The slice must have
+// room for at least [NetHdrSize] bytes.
+func (v *NetHdr) Encode(data []byte) error {
+	if len(data) < NetHdrSize {
+		return ErrNetHdrBufferTooSmall
+	}
+	copy(data[:NetHdrSize], unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize))
+	return nil
+}
overlay/virtio/net_hdr_test.go (new file, 43 lines)
@@ -0,0 +1,43 @@
+package virtio
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/sys/unix"
+)
+
+func TestNetHdr_Size(t *testing.T) {
+	assert.EqualValues(t, NetHdrSize, unsafe.Sizeof(NetHdr{}))
+}
+
+func TestNetHdr_Encoding(t *testing.T) {
+	vnethdr := NetHdr{
+		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+		HdrLen:     42,
+		GSOSize:    1472,
+		CsumStart:  34,
+		CsumOffset: 6,
+		NumBuffers: 16,
+	}
+
+	buf := make([]byte, NetHdrSize)
+	require.NoError(t, vnethdr.Encode(buf))
+
+	assert.Equal(t, []byte{
+		0x01, 0x05,
+		0x2a, 0x00,
+		0xc0, 0x05,
+		0x22, 0x00,
+		0x06, 0x00,
+		0x10, 0x00,
+	}, buf)
+
+	var decoded NetHdr
+	require.NoError(t, decoded.Decode(buf))
+
+	assert.Equal(t, vnethdr, decoded)
+}
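The fixture values in TestNetHdr_Encoding are not arbitrary; they describe UDP segmentation offload on a standard Ethernet/IPv4/UDP frame at a 1500-byte MTU. Assuming that framing, they decompose as:

	const (
		ethHdrLen  = 14
		ipv4HdrLen = 20
		udpHdrLen  = 8
	)
	hdrLen := ethHdrLen + ipv4HdrLen + udpHdrLen // 42: header bytes replicated per segment
	gsoSize := 1500 - ipv4HdrLen - udpHdrLen     // 1472: UDP payload carried per segment
	csumStart := ethHdrLen + ipv4HdrLen          // 34: checksum covers the UDP header onward
	csumOffset := 6                              // the UDP checksum field sits 6 bytes in

The expected bytes then follow from little-endian encoding, e.g. 0xc0, 0x05 is 1472.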
overlay/virtqueue/available_ring.go (new file, 103 lines)
@@ -0,0 +1,103 @@
+package virtqueue
+
+import (
+	"fmt"
+	"sync"
+	"unsafe"
+)
+
+// availableRingFlag is a flag that describes an [AvailableRing].
+type availableRingFlag uint16
+
+const (
+	// availableRingFlagNoInterrupt is used by the guest to advise the host to
+	// not interrupt it when consuming a buffer. It's unreliable, so it's
+	// simply an optimization.
+	availableRingFlagNoInterrupt availableRingFlag = 1 << iota
+)
+
+// availableRingSize is the number of bytes needed to store an [AvailableRing]
+// with the given queue size in memory.
+func availableRingSize(queueSize int) int {
+	return 6 + 2*queueSize
+}
+
+// availableRingAlignment is the minimum alignment of an [AvailableRing]
+// in memory, as required by the virtio spec.
+const availableRingAlignment = 2
+
+// AvailableRing is used by the driver to offer descriptor chains to the
+// device. Each ring entry refers to the head of a descriptor chain. It is
+// only written to by the driver and read by the device.
+//
+// Because the size of the ring depends on the queue size, we cannot define a
+// Go struct with a static size that maps to the memory of the ring. Instead,
+// this struct only contains pointers to the corresponding memory areas.
+type AvailableRing struct {
+	initialized bool
+
+	// flags that describe this ring.
+	flags *availableRingFlag
+	// ringIndex indicates where the driver would put the next entry into the
+	// ring (modulo the queue size).
+	ringIndex *uint16
+	// ring references buffers using the index of the head of the descriptor
+	// chain in the [DescriptorTable]. It wraps around at queue size.
+	ring []uint16
+	// usedEvent is not used by this implementation, but we reserve it anyway
+	// to avoid issues in case a device may try to access it, contrary to the
+	// virtio specification.
+	usedEvent *uint16
+
+	mu sync.Mutex
+}
+
+// newAvailableRing creates an available ring that uses the given underlying
+// memory. The length of the memory slice must match the size needed for the
+// ring (see [availableRingSize]) for the given queue size.
+func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
+	ringSize := availableRingSize(queueSize)
+	if len(mem) != ringSize {
+		panic(fmt.Sprintf("memory size (%v) does not match required size "+
+			"for available ring: %v", len(mem), ringSize))
+	}
+
+	return &AvailableRing{
+		initialized: true,
+		flags:       (*availableRingFlag)(unsafe.Pointer(&mem[0])),
+		ringIndex:   (*uint16)(unsafe.Pointer(&mem[2])),
+		ring:        unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
+		usedEvent:   (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
+	}
+}
+
+// Address returns the pointer to the beginning of the ring in memory.
+// Do not modify the memory directly to not interfere with this
+// implementation.
+func (r *AvailableRing) Address() uintptr {
+	if !r.initialized {
+		panic("available ring is not initialized")
+	}
+	return uintptr(unsafe.Pointer(r.flags))
+}
+
+// offer adds the given descriptor chain heads to the available ring and
+// advances the ring index accordingly to make the device process the new
+// descriptor chains.
+func (r *AvailableRing) offer(chainHeads []uint16) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	// Add descriptor chain heads to the ring.
+	for offset, head := range chainHeads {
+		// The 16-bit ring index may overflow. This is expected and is not an
+		// issue because the size of the ring array (which equals the queue
+		// size) is always a power of 2 and smaller than the highest possible
+		// 16-bit value.
+		insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+		r.ring[insertIndex] = head
+	}
+
+	// Increase the ring index by the number of descriptor chains added to
+	// the ring.
+	*r.ringIndex += uint16(len(chainHeads))
+}
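The free-running 16-bit index in offer only works because the ring length is a power of two. A worked example, mirroring the "index overflow" case in the test file that follows:

	const queueSize = 8 // must be a power of two
	var ringIndex uint16 = 65535

	ring := make([]uint16, queueSize)
	for offset, head := range []uint16{42, 33, 69} {
		// uint16 addition wraps before the modulo, e.g. 65535+1 -> 0.
		ring[int(ringIndex+uint16(offset))%queueSize] = head
	}
	ringIndex += 3 // 65535 + 3 wraps to 2 (mod 65536)
	// ring is now [33 69 0 0 0 0 0 42]. The wrap is harmless because 65536
	// is a multiple of every power-of-two queue size; a size like 6 would
	// shift the slot sequence at the wrap and corrupt the ring.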
overlay/virtqueue/available_ring_internal_test.go (new file, 71 lines)
@@ -0,0 +1,71 @@
+package virtqueue
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestAvailableRing_MemoryLayout(t *testing.T) {
+	const queueSize = 2
+
+	memory := make([]byte, availableRingSize(queueSize))
+	r := newAvailableRing(queueSize, memory)
+
+	*r.flags = 0x01ff
+	*r.ringIndex = 1
+	r.ring[0] = 0x1234
+	r.ring[1] = 0x5678
+
+	assert.Equal(t, []byte{
+		0xff, 0x01,
+		0x01, 0x00,
+		0x34, 0x12,
+		0x78, 0x56,
+		0x00, 0x00,
+	}, memory)
+}
+
+func TestAvailableRing_Offer(t *testing.T) {
+	const queueSize = 8
+
+	chainHeads := []uint16{42, 33, 69}
+
+	tests := []struct {
+		name              string
+		startRingIndex    uint16
+		expectedRingIndex uint16
+		expectedRing      []uint16
+	}{
+		{
+			name:              "no overflow",
+			startRingIndex:    0,
+			expectedRingIndex: 3,
+			expectedRing:      []uint16{42, 33, 69, 0, 0, 0, 0, 0},
+		},
+		{
+			name:              "ring overflow",
+			startRingIndex:    6,
+			expectedRingIndex: 9,
+			expectedRing:      []uint16{69, 0, 0, 0, 0, 0, 42, 33},
+		},
+		{
+			name:              "index overflow",
+			startRingIndex:    65535,
+			expectedRingIndex: 2,
+			expectedRing:      []uint16{33, 69, 0, 0, 0, 0, 0, 42},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			memory := make([]byte, availableRingSize(queueSize))
+			r := newAvailableRing(queueSize, memory)
+			*r.ringIndex = tt.startRingIndex
+
+			r.offer(chainHeads)
+
+			assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
+			assert.Equal(t, tt.expectedRing, r.ring)
+		})
+	}
+}
overlay/virtqueue/descriptor.go (new file, 43 lines)
@@ -0,0 +1,43 @@
+package virtqueue
+
+// descriptorFlag is a flag that describes a [Descriptor].
+type descriptorFlag uint16
+
+const (
+	// descriptorFlagHasNext marks a descriptor chain as continuing via the
+	// next field.
+	descriptorFlagHasNext descriptorFlag = 1 << iota
+	// descriptorFlagWritable marks a buffer as device write-only (otherwise
+	// device read-only).
+	descriptorFlagWritable
+	// descriptorFlagIndirect means the buffer contains a list of buffer
+	// descriptors to provide an additional layer of indirection.
+	// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
+	// negotiated.
+	descriptorFlagIndirect
+)
+
+// descriptorSize is the number of bytes needed to store a [Descriptor] in
+// memory.
+const descriptorSize = 16
+
+// Descriptor describes (a part of) a buffer which is either read-only for the
+// device or write-only for the device (depending on [descriptorFlagWritable]).
+// Multiple descriptors can be chained to produce a "descriptor chain" that can
+// contain both device-readable and device-writable buffers. Device-readable
+// descriptors always come first in a chain. A single, large buffer may be
+// split up by chaining multiple similar descriptors that reference different
+// memory pages. This is required because buffers may exceed a single page
+// size and the memory accessed by the device is expected to be contiguous.
+type Descriptor struct {
+	// address is the address of the contiguous memory holding the data for
+	// this descriptor.
+	address uintptr
+	// length is the number of bytes stored at address.
+	length uint32
+	// flags that describe this descriptor.
+	flags descriptorFlag
+	// next contains the index of the next descriptor continuing this
+	// descriptor chain when the [descriptorFlagHasNext] flag is set.
+	next uint16
+}
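To make the flag semantics concrete, here is a sketch of a two-descriptor chain carrying one device-readable request buffer and one device-writable response page. The addresses are hypothetical placeholders, and since the fields are unexported this only compiles inside the virtqueue package:

	// Sketch: a chain at indices 3 -> 5.
	table := make([]Descriptor, 8)
	table[3] = Descriptor{
		address: outPageAddr, // hypothetical page holding the request bytes
		length:  128,         // only the filled prefix counts
		flags:   descriptorFlagHasNext,
		next:    5,
	}
	table[5] = Descriptor{
		address: inPageAddr,             // hypothetical page the device writes into
		length:  4096,                   // the whole page is offered for the response
		flags:   descriptorFlagWritable, // no has-next flag: this is the tail
	}
	// The available ring then carries the head index 3; the device walks the
	// next pointers until it hits a descriptor without descriptorFlagHasNext.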
overlay/virtqueue/descriptor_internal_test.go (new file, 12 lines)
@@ -0,0 +1,12 @@
+package virtqueue
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestDescriptor_Size(t *testing.T) {
+	assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
+}
437
overlay/virtqueue/descriptor_table.go
Normal file
437
overlay/virtqueue/descriptor_table.go
Normal file
@@ -0,0 +1,437 @@
|
|||||||
|
package virtqueue
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
|
||||||
|
// no buffers, which is not allowed.
|
||||||
|
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
|
||||||
|
|
||||||
|
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
|
||||||
|
// exhausted, meaning that the queue is full.
|
||||||
|
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
|
||||||
|
|
||||||
|
// ErrInvalidDescriptorChain is returned when a descriptor chain is not
|
||||||
|
// valid for a given operation.
|
||||||
|
ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
|
||||||
|
)
|
||||||
|
|
||||||
|
// noFreeHead is used to mark when all descriptors are in use and we have no
|
||||||
|
// free chain. This value is impossible to occur as an index naturally, because
|
||||||
|
// it exceeds the maximum queue size.
|
||||||
|
const noFreeHead = uint16(math.MaxUint16)
|
||||||
|
|
||||||
|
// descriptorTableSize is the number of bytes needed to store a
|
||||||
|
// [DescriptorTable] with the given queue size in memory.
|
||||||
|
func descriptorTableSize(queueSize int) int {
|
||||||
|
return descriptorSize * queueSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
|
||||||
|
// in memory, as required by the virtio spec.
|
||||||
|
const descriptorTableAlignment = 16
|
||||||
|
|
||||||
|
// DescriptorTable is a table that holds [Descriptor]s, addressed via their
|
||||||
|
// index in the slice.
|
||||||
|
type DescriptorTable struct {
|
||||||
|
descriptors []Descriptor
|
||||||
|
|
||||||
|
// freeHeadIndex is the index of the head of the descriptor chain which
|
||||||
|
// contains all currently unused descriptors. When all descriptors are in
|
||||||
|
// use, this has the special value of noFreeHead.
|
||||||
|
freeHeadIndex uint16
|
||||||
|
// freeNum tracks the number of descriptors which are currently not in use.
|
||||||
|
freeNum uint16
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDescriptorTable creates a descriptor table that uses the given underlying
|
||||||
|
// memory. The Length of the memory slice must match the size needed for the
|
||||||
|
// descriptor table (see [descriptorTableSize]) for the given queue size.
|
||||||
|
//
|
||||||
|
// Before this descriptor table can be used, [initialize] must be called.
|
||||||
|
func newDescriptorTable(queueSize int, mem []byte) *DescriptorTable {
|
||||||
|
dtSize := descriptorTableSize(queueSize)
|
||||||
|
if len(mem) != dtSize {
|
||||||
|
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
||||||
|
"for descriptor table: %v", len(mem), dtSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DescriptorTable{
|
||||||
|
descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
|
||||||
|
// We have no free descriptors until they were initialized.
|
||||||
|
freeHeadIndex: noFreeHead,
|
||||||
|
freeNum: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Address returns the pointer to the beginning of the descriptor table in
|
||||||
|
// memory. Do not modify the memory directly to not interfere with this
|
||||||
|
// implementation.
|
||||||
|
func (dt *DescriptorTable) Address() uintptr {
|
||||||
|
if dt.descriptors == nil {
|
||||||
|
panic("descriptor table is not initialized")
|
||||||
|
}
|
||||||
|
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BufferAddresses returns pointers to all memory pages used by the descriptor
|
||||||
|
// table to store buffers, independent of whether descriptors are currently in
|
||||||
|
// use or not. These pointers can be helpful to set up memory mappings. Do not
|
||||||
|
// use them to access or modify the memory in any way.
|
||||||
|
// Each pointer points to a whole memory page. Use [os.Getpagesize] to get the
|
||||||
|
// page size.
|
||||||
|
func (dt *DescriptorTable) BufferAddresses() []uintptr {
|
||||||
|
if dt.descriptors == nil {
|
||||||
|
panic("descriptor table is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
dt.mu.Lock()
|
||||||
|
defer dt.mu.Unlock()
|
||||||
|
|
||||||
|
ptrs := make([]uintptr, len(dt.descriptors))
|
||||||
|
for i, desc := range dt.descriptors {
|
||||||
|
ptrs[i] = desc.address
|
||||||
|
}
|
||||||
|
|
||||||
|
return ptrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializeDescriptors allocates buffers with the size of a full memory page
|
||||||
|
// for each descriptor in the table. While this may be a bit wasteful, it makes
|
||||||
|
// dealing with descriptors way easier. Without this preallocation, we would
|
||||||
|
// have to allocate and free memory on demand, increasing complexity.
|
||||||
|
//
|
||||||
|
// All descriptors will be marked as free and will form a free chain. The
|
||||||
|
// addresses of all descriptors will be populated while their length remains
|
||||||
|
// zero.
|
||||||
|
func (dt *DescriptorTable) initializeDescriptors() error {
|
||||||
|
pageSize := os.Getpagesize()
|
||||||
|
|
||||||
|
dt.mu.Lock()
|
||||||
|
defer dt.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range dt.descriptors {
|
||||||
|
// Allocate a full memory page for this descriptor.
|
||||||
|
pagePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(pageSize),
|
||||||
|
unix.PROT_READ|unix.PROT_WRITE,
|
||||||
|
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("allocate page for descriptor %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dt.descriptors[i] = Descriptor{
|
||||||
|
address: uintptr(pagePtr),
|
||||||
|
length: 0,
|
||||||
|
// All descriptors should form a free chain that loops around.
|
||||||
|
flags: descriptorFlagHasNext,
|
||||||
|
next: uint16((i + 1) % len(dt.descriptors)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All descriptors are free to use now.
|
||||||
|
dt.freeHeadIndex = 0
|
||||||
|
dt.freeNum = uint16(len(dt.descriptors))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// releaseBuffers releases all allocated buffers for this descriptor table.
|
||||||
|
// The implementation will try to release as many buffers as possible and
|
||||||
|
// collect potential errors before returning them.
|
||||||
|
// The descriptor table should no longer be used after calling this.
|
||||||
|
func (dt *DescriptorTable) releaseBuffers() error {
|
||||||
|
dt.mu.Lock()
|
||||||
|
defer dt.mu.Unlock()
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
pageSize := os.Getpagesize()
|
||||||
|
for i := range dt.descriptors {
|
||||||
|
descriptor := &dt.descriptors[i]
|
||||||
|
if descriptor.address == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// The pointer points to memory not managed by Go, so this conversion
|
||||||
|
// is safe. See https://github.com/golang/go/issues/58625
|
||||||
|
//goland:noinspection GoVetUnsafePointer
|
||||||
|
err := unix.MunmapPtr(unsafe.Pointer(descriptor.address), uintptr(pageSize))
|
||||||
|
if err == nil {
|
||||||
|
descriptor.address = 0
|
||||||
|
} else {
|
||||||
|
errs = append(errs, fmt.Errorf("release page for descriptor %d: %w", i, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// As a safety measure, make sure no descriptors can be used anymore.
|
||||||
|
dt.freeHeadIndex = noFreeHead
|
||||||
|
dt.freeNum = 0
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createDescriptorChain creates a new descriptor chain within the descriptor
|
||||||
|
// table which contains a number of device-readable buffers (out buffers) and
|
||||||
|
// device-writable buffers (in buffers).
|
||||||
|
//
|
||||||
|
// All buffers in the outBuffers slice will be concatenated by chaining
|
||||||
|
// descriptors, one for each buffer in the slice. The size of the single buffers
|
||||||
|
// must not exceed the size of a memory page (see [os.Getpagesize]).
|
||||||
|
// When numInBuffers is greater than zero, the given number of device-writable
|
||||||
|
// descriptors will be appended to the end of the chain, each referencing a
|
||||||
|
// whole memory page.
|
||||||
|
//
|
||||||
|
// The index of the head of the new descriptor chain will be returned. Callers
|
||||||
|
// should make sure to free the descriptor chain using [freeDescriptorChain]
|
||||||
|
// after it was used by the device.
|
||||||
|
//
|
||||||
|
// When there are not enough free descriptors to hold the given number of
|
||||||
|
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
|
||||||
|
// caller should try again after some descriptor chains were used by the device
|
||||||
|
// and returned back into the free chain.
|
||||||
|
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
|
||||||
|
pageSize := os.Getpagesize()
|
||||||
|
|
||||||
|
// Calculate the number of descriptors needed to build the chain.
|
||||||
|
numDesc := uint16(len(outBuffers) + numInBuffers)
|
||||||
|
|
||||||
|
// Descriptor chains must always contain at least one descriptor.
|
||||||
|
if numDesc < 1 {
|
||||||
|
return 0, ErrDescriptorChainEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
dt.mu.Lock()
|
||||||
|
defer dt.mu.Unlock()
|
||||||
|
|
||||||
|
// Do we still have enough free descriptors?
|
||||||
|
if numDesc > dt.freeNum {
|
||||||
|
return 0, fmt.Errorf("%w: %d free but needed %d", ErrNotEnoughFreeDescriptors, dt.freeNum, numDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Above validation ensured that there is at least one free descriptor, so
|
||||||
|
// the free descriptor chain head should be valid.
|
||||||
|
if dt.freeHeadIndex == noFreeHead {
|
||||||
|
panic("free descriptor chain head is unset but there should be free descriptors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// To avoid having to iterate over the whole table to find the descriptor
|
||||||
|
// pointing to the head just to replace the free head, we instead always
|
||||||
|
// create descriptor chains from the descriptors coming after the head.
|
||||||
|
// This way we only have to touch the head as a last resort, when all other
|
||||||
|
// descriptors are already used.
|
||||||
|
head := dt.descriptors[dt.freeHeadIndex].next
|
||||||
|
next := head
|
||||||
|
tail := head
|
||||||
|
for i, buffer := range outBuffers {
|
||||||
|
desc := &dt.descriptors[next]
|
||||||
|
checkUnusedDescriptorLength(next, desc)
|
||||||
|
|
||||||
|
if len(buffer) > pageSize {
|
||||||
|
// The caller should already prevent that from happening.
|
||||||
|
panic(fmt.Sprintf("out buffer %d has size %d which exceeds page size %d", i, len(buffer), pageSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the buffer to the memory referenced by the descriptor.
|
||||||
|
// The descriptor address points to memory not managed by Go, so this
|
||||||
|
// conversion is safe. See https://github.com/golang/go/issues/58625
|
||||||
|
//goland:noinspection GoVetUnsafePointer
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), pageSize), buffer)
|
||||||
|
desc.length = uint32(len(buffer))
|
||||||
|
|
||||||
|
// Clear the flags in case there were any others set.
|
||||||
|
desc.flags = descriptorFlagHasNext
|
||||||
|
|
||||||
|
tail = next
|
||||||
|
next = desc.next
|
||||||
|
}
|
||||||
|
for range numInBuffers {
|
||||||
|
desc := &dt.descriptors[next]
|
||||||
|
checkUnusedDescriptorLength(next, desc)
|
||||||
|
|
||||||
|
// Give the device the maximum available number of bytes to write into.
|
||||||
|
desc.length = uint32(pageSize)
|
||||||
|
|
||||||
|
// Mark the descriptor as device-writable.
|
||||||
|
desc.flags = descriptorFlagHasNext | descriptorFlagWritable
|
||||||
|
|
||||||
|
tail = next
|
||||||
|
next = desc.next
|
||||||
|
}
|
||||||
|
|
||||||
|
// The last descriptor should end the chain.
|
||||||
|
tailDesc := &dt.descriptors[tail]
|
||||||
|
tailDesc.flags &= ^descriptorFlagHasNext
|
||||||
|
tailDesc.next = 0 // Not necessary to clear this, it's just for looks.
|
||||||
|
|
||||||
|
dt.freeNum -= numDesc
|
||||||
|
|
||||||
|
if dt.freeNum == 0 {
|
||||||
|
// The last descriptor in the chain should be the free chain head
|
||||||
|
// itself.
|
||||||
|
if tail != dt.freeHeadIndex {
|
||||||
|
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
||||||
|
}
|
||||||
|
|
||||||
|
// When this new chain takes up all remaining descriptors, we no longer
|
||||||
|
// have a free chain.
|
||||||
|
dt.freeHeadIndex = noFreeHead
|
||||||
|
} else {
|
||||||
|
// We took some descriptors out of the free chain, so make sure to close
|
||||||
|
// the circle again.
|
||||||
|
dt.descriptors[dt.freeHeadIndex].next = next
|
||||||
|
}
|
||||||
|
|
||||||
|
return head, nil
|
||||||
|
}
|

// TODO: Implement a zero-copy variant of createDescriptorChain?

// getDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain that starts with
// the given head index. The descriptor chain must have been created using
// [createDescriptorChain] and must not have been freed yet (meaning that the
// head index must not be contained in the free chain).
//
// Be careful to only access the returned buffer slices while the device is not
// using them. They must not be accessed after [freeDescriptorChain] has been
// called.
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
	if int(head) >= len(dt.descriptors) {
		return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
	}

	dt.mu.Lock()
	defer dt.mu.Unlock()

	// Iterate over the chain. The iteration is limited to the queue size to
	// avoid ending up in an endless loop when things go very wrong.
	next := head
	for range len(dt.descriptors) {
		if next == dt.freeHeadIndex {
			return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
		}

		desc := &dt.descriptors[next]

		// The descriptor address points to memory not managed by Go, so this
		// conversion is safe. See https://github.com/golang/go/issues/58625
		//goland:noinspection GoVetUnsafePointer
		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)

		if desc.flags&descriptorFlagWritable == 0 {
			outBuffers = append(outBuffers, bs)
		} else {
			inBuffers = append(inBuffers, bs)
		}

		// Is this the tail of the chain?
		if desc.flags&descriptorFlagHasNext == 0 {
			break
		}

		// Detect loops.
		if desc.next == head {
			return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
		}

		next = desc.next
	}

	return
}

// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
// of [createDescriptorChain].
// The descriptor chain must have been created using [createDescriptorChain] and
// must not have been freed yet (meaning that the head index must not be
// contained in the free chain).
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
	if int(head) >= len(dt.descriptors) {
		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
	}

	dt.mu.Lock()
	defer dt.mu.Unlock()

	// Iterate over the chain. The iteration is limited to the queue size to
	// avoid ending up in an endless loop when things go very wrong.
	next := head
	var tailDesc *Descriptor
	var chainLen uint16
	for range len(dt.descriptors) {
		if next == dt.freeHeadIndex {
			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
		}

		desc := &dt.descriptors[next]
		chainLen++

		// Set the length of all unused descriptors back to zero.
		desc.length = 0

		// Unset all flags except the next flag.
		desc.flags &= descriptorFlagHasNext

		// Is this the tail of the chain?
		if desc.flags&descriptorFlagHasNext == 0 {
			tailDesc = desc
			break
		}

		// Detect loops.
		if desc.next == head {
			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
		}

		next = desc.next
	}
	if tailDesc == nil {
		// A descriptor chain longer than the queue size but without loops
		// should be impossible.
		panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
	}

	// The tail descriptor does not have the next flag set, but it needs it
	// set again once it is back in the free chain.
	tailDesc.flags = descriptorFlagHasNext

	if dt.freeHeadIndex == noFreeHead {
		// The whole free chain was used up, so we turn this returned descriptor
		// chain into the new free chain by completing the circle and using its
		// head.
		tailDesc.next = head
		dt.freeHeadIndex = head
	} else {
		// Attach the returned chain at the beginning of the free chain but
		// right after the free chain head.
		freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
		tailDesc.next = freeHeadDesc.next
		freeHeadDesc.next = head
	}

	dt.freeNum += chainLen

	return nil
}

// checkUnusedDescriptorLength asserts that the length of an unused descriptor
// is zero, as it should be.
// This is not required by the virtio spec; it is a sanity check that helps us
// notice when our bookkeeping goes sideways.
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
	if desc.length != 0 {
		panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
	}
}
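
// Taken together, these internal methods form the lifecycle of a descriptor
// chain. A hedged sketch of that flow (dt and payload are placeholders and
// error handling is elided):
//
//	head, _ := dt.createDescriptorChain([][]byte{payload}, 1)
//	outBufs, inBufs, _ := dt.getDescriptorChain(head)
//	// ... hand the chain to the device and wait until it is done ...
//	_ = dt.freeDescriptorChain(head)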
407 overlay/virtqueue/descriptor_table_internal_test.go Normal file
@@ -0,0 +1,407 @@
package virtqueue

import (
	"os"
	"testing"
	"unsafe"

	"github.com/stretchr/testify/assert"
)

func TestDescriptorTable_InitializeDescriptors(t *testing.T) {
	const queueSize = 32

	dt := DescriptorTable{
		descriptors: make([]Descriptor, queueSize),
	}

	assert.NoError(t, dt.initializeDescriptors())
	t.Cleanup(func() {
		assert.NoError(t, dt.releaseBuffers())
	})

	for i, descriptor := range dt.descriptors {
		assert.NotZero(t, descriptor.address)
		assert.Zero(t, descriptor.length)
		assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags)
		assert.EqualValues(t, (i+1)%queueSize, descriptor.next)
	}
}

func TestDescriptorTable_DescriptorChains(t *testing.T) {
	// Use a very short queue size to not make this test overly verbose.
	const queueSize = 8

	pageSize := os.Getpagesize()

	// Initialize descriptor table.
	dt := DescriptorTable{
		descriptors: make([]Descriptor, queueSize),
	}
	assert.NoError(t, dt.initializeDescriptors())
	t.Cleanup(func() {
		assert.NoError(t, dt.releaseBuffers())
	})

	// Some utilities for easier checking if the descriptor table looks as
	// expected.
	type desc struct {
		buffer []byte
		flags  descriptorFlag
		next   uint16
	}
	assertDescriptorTable := func(expected [queueSize]desc) {
		for i := 0; i < queueSize; i++ {
			actualDesc := &dt.descriptors[i]
			expectedDesc := &expected[i]
			assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length)
			if len(expectedDesc.buffer) > 0 {
				//goland:noinspection GoVetUnsafePointer
				assert.EqualValues(t,
					unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length),
					expectedDesc.buffer)
			}
			assert.Equal(t, expectedDesc.flags, actualDesc.flags)
			if expectedDesc.flags&descriptorFlagHasNext != 0 {
				assert.Equal(t, expectedDesc.next, actualDesc.next)
			}
		}
	}

	// Initial state: All descriptors are in the free chain.
	assert.Equal(t, uint16(0), dt.freeHeadIndex)
	assert.Equal(t, uint16(8), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  1,
		},
		{
			flags: descriptorFlagHasNext,
			next:  2,
		},
		{
			flags: descriptorFlagHasNext,
			next:  3,
		},
		{
			flags: descriptorFlagHasNext,
			next:  4,
		},
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Create the first chain.
	firstChain, err := dt.createDescriptorChain([][]byte{
		makeTestBuffer(t, 26),
		makeTestBuffer(t, 256),
	}, 1)
	assert.NoError(t, err)
	assert.Equal(t, uint16(1), firstChain)

	// Now there should be a new chain next to the free chain.
	assert.Equal(t, uint16(0), dt.freeHeadIndex)
	assert.Equal(t, uint16(5), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  4,
		},
		{
			// Head of first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Create a second chain with only a single in buffer.
	secondChain, err := dt.createDescriptorChain(nil, 1)
	assert.NoError(t, err)
	assert.Equal(t, uint16(4), secondChain)

	// Now there should be two chains next to the free chain.
	assert.Equal(t, uint16(0), dt.freeHeadIndex)
	assert.Equal(t, uint16(4), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			// Head of the first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of the first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Create a third chain taking up all remaining descriptors.
	thirdChain, err := dt.createDescriptorChain([][]byte{
		makeTestBuffer(t, 42),
		makeTestBuffer(t, 96),
		makeTestBuffer(t, 33),
		makeTestBuffer(t, 222),
	}, 0)
	assert.NoError(t, err)
	assert.Equal(t, uint16(5), thirdChain)

	// Now there should be three chains and no free chain.
	assert.Equal(t, noFreeHead, dt.freeHeadIndex)
	assert.Equal(t, uint16(0), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			// Tail of the third chain.
			buffer: makeTestBuffer(t, 222),
		},
		{
			// Head of the first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of the first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head of the third chain.
			buffer: makeTestBuffer(t, 42),
			flags:  descriptorFlagHasNext,
			next:   6,
		},
		{
			buffer: makeTestBuffer(t, 96),
			flags:  descriptorFlagHasNext,
			next:   7,
		},
		{
			buffer: makeTestBuffer(t, 33),
			flags:  descriptorFlagHasNext,
			next:   0,
		},
	})

	// Free the third chain.
	assert.NoError(t, dt.freeDescriptorChain(thirdChain))

	// Now there should be two chains and a free chain again.
	assert.Equal(t, uint16(5), dt.freeHeadIndex)
	assert.Equal(t, uint16(4), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			// Head of the first chain.
			buffer: makeTestBuffer(t, 26),
			flags:  descriptorFlagHasNext,
			next:   2,
		},
		{
			buffer: makeTestBuffer(t, 256),
			flags:  descriptorFlagHasNext,
			next:   3,
		},
		{
			// Tail of the first chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Free the first chain.
	assert.NoError(t, dt.freeDescriptorChain(firstChain))

	// Now there should be only a single chain next to the free chain.
	assert.Equal(t, uint16(5), dt.freeHeadIndex)
	assert.Equal(t, uint16(7), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  2,
		},
		{
			flags: descriptorFlagHasNext,
			next:  3,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			// Head and tail of the second chain.
			buffer: make([]byte, pageSize),
			flags:  descriptorFlagWritable,
		},
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  1,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})

	// Free the second chain.
	assert.NoError(t, dt.freeDescriptorChain(secondChain))

	// Now all descriptors should be in the free chain again.
	assert.Equal(t, uint16(5), dt.freeHeadIndex)
	assert.Equal(t, uint16(8), dt.freeNum)
	assertDescriptorTable([queueSize]desc{
		{
			flags: descriptorFlagHasNext,
			next:  5,
		},
		{
			flags: descriptorFlagHasNext,
			next:  2,
		},
		{
			flags: descriptorFlagHasNext,
			next:  3,
		},
		{
			flags: descriptorFlagHasNext,
			next:  6,
		},
		{
			flags: descriptorFlagHasNext,
			next:  1,
		},
		{
			// Free head.
			flags: descriptorFlagHasNext,
			next:  4,
		},
		{
			flags: descriptorFlagHasNext,
			next:  7,
		},
		{
			flags: descriptorFlagHasNext,
			next:  0,
		},
	})
}

func makeTestBuffer(t *testing.T, length int) []byte {
	t.Helper()
	buf := make([]byte, length)
	for i := 0; i < length; i++ {
		buf[i] = byte(length - i)
	}
	return buf
}
7 overlay/virtqueue/doc.go Normal file
@@ -0,0 +1,7 @@
// Package virtqueue implements the driver side of a virtio queue as described
// in the specification:
// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
//
// This package makes no assumptions about the device that consumes the queue.
// It merely allocates the queue structures in memory and provides methods to
// interact with them.
package virtqueue
45 overlay/virtqueue/eventfd_test.go Normal file
@@ -0,0 +1,45 @@
package virtqueue

import (
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gvisor.dev/gvisor/pkg/eventfd"
)

// Tests how an eventfd and a waiting goroutine can be gracefully closed.
// Extends the eventfd test suite:
// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
func TestEventFD_CancelWait(t *testing.T) {
	efd, err := eventfd.Create()
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, efd.Close())
	})

	// stop is shared with the goroutine below, so use an atomic to keep the
	// test free of data races.
	var stop atomic.Bool

	done := make(chan struct{})
	go func() {
		for !stop.Load() {
			_ = efd.Wait()
		}
		close(done)
	}()
	select {
	case <-done:
		t.Fatal("goroutine ended early")
	case <-time.After(500 * time.Millisecond):
	}

	stop.Store(true)
	assert.NoError(t, efd.Notify())
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Error("goroutine did not end")
	}
}
33 overlay/virtqueue/size.go Normal file
@@ -0,0 +1,33 @@
package virtqueue

import (
	"errors"
	"fmt"
)

// ErrQueueSizeInvalid is returned when a queue size is invalid.
var ErrQueueSizeInvalid = errors.New("queue size is invalid")

// CheckQueueSize checks whether the given value would be a valid size for a
// virtqueue and returns an error wrapping [ErrQueueSizeInvalid] if not.
func CheckQueueSize(queueSize int) error {
	if queueSize <= 0 {
		return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
	}

	// The queue size must always be a power of 2.
	// This ensures that ring indexes wrap correctly when the 16-bit integers
	// overflow.
	if queueSize&(queueSize-1) != 0 {
		return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
	}
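	// Worked example for the bit trick above: 256 is 0b100000000 and 255 is
	// 0b011111111, so 256&255 == 0 and 256 passes, while 24 (0b11000) & 23
	// (0b10111) == 0b10000 != 0, so 24 is rejected.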

	// The largest power of 2 that fits into a 16-bit integer is 32768.
	// 2 * 32768 would be 65536 which no longer fits.
	if queueSize > 32768 {
		return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
			ErrQueueSizeInvalid, queueSize)
	}

	return nil
}
59 overlay/virtqueue/size_test.go Normal file
@@ -0,0 +1,59 @@
package virtqueue

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestCheckQueueSize(t *testing.T) {
	tests := []struct {
		name        string
		queueSize   int
		containsErr string
	}{
		{
			name:        "negative",
			queueSize:   -1,
			containsErr: "too small",
		},
		{
			name:        "zero",
			queueSize:   0,
			containsErr: "too small",
		},
		{
			name:        "not a power of 2",
			queueSize:   24,
			containsErr: "not a power of 2",
		},
		{
			name:        "too large",
			queueSize:   65536,
			containsErr: "larger than the maximum",
		},
		{
			name:      "valid 1",
			queueSize: 1,
		},
		{
			name:      "valid 256",
			queueSize: 256,
		},
		{
			name:      "valid 32768",
			queueSize: 32768,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := CheckQueueSize(tt.queueSize)
			if tt.containsErr != "" {
				assert.ErrorContains(t, err, tt.containsErr)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
427 overlay/virtqueue/split_virtqueue.go Normal file
@@ -0,0 +1,427 @@
package virtqueue

import (
	"context"
	"errors"
	"fmt"
	"os"
	"sync"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/eventfd"
)

// SplitQueue is a virtqueue that consists of several parts, where each part is
// writeable by either the driver or the device, but not both.
type SplitQueue struct {
	// size is the size of the queue.
	size int
	// buf is the underlying memory used for the queue.
	buf []byte

	descriptorTable *DescriptorTable
	availableRing   *AvailableRing
	usedRing        *UsedRing

	// kickEventFD is used to signal the device when descriptor chains were
	// added to the available ring.
	kickEventFD eventfd.Eventfd
	// callEventFD is used by the device to signal when it has used descriptor
	// chains and put them in the used ring.
	callEventFD eventfd.Eventfd

	// usedChains is a channel that receives [UsedElement]s for descriptor
	// chains that were used by the device.
	usedChains chan UsedElement

	// moreFreeDescriptors is a channel that signals when any descriptors were
	// put back into the free chain of the descriptor table. This is used to
	// unblock methods waiting for available room in the queue to create new
	// descriptor chains again.
	moreFreeDescriptors chan struct{}

	// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
	// used buffer notifications. It blocks until the goroutine has ended.
	stop func() error

	// offerMutex is used to synchronize calls to
	// [SplitQueue.OfferDescriptorChain].
	offerMutex sync.Mutex
}

// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
// specifies the number of entries/buffers the queue can hold. This also affects
// the memory consumption.
func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
	if err = CheckQueueSize(queueSize); err != nil {
		return nil, err
	}

	sq := SplitQueue{
		size: queueSize,
	}

	// Clean up a partially initialized queue when something fails.
	defer func() {
		if err != nil {
			_ = sq.Close()
		}
	}()

	// There are multiple ways in which the memory for the virtqueue could be
	// allocated. We could use Go-native structs with arrays inside them, but
	// this wouldn't allow us to make the queue size configurable. And including
	// a slice in the Go structs wouldn't work, because this would just put the
	// Go slice descriptor into the memory region, which the virtio device will
	// not understand.
	// Additionally, Go does not allow us to ensure a correct alignment of the
	// parts of the virtqueue, as required by the virtio specification.
	//
	// To resolve this, let's just allocate the memory manually by allocating
	// one or more memory pages, depending on the queue size. Making the
	// virtqueue start at the beginning of a page is not strictly necessary, as
	// the virtio specification does not require it to be contiguous in the
	// physical memory of the host (e.g. the vhost implementation in the kernel
	// always uses copy_from_user to access it), but this makes it very easy to
	// guarantee the alignment. Also, it is not required for the virtqueue parts
	// to be in the same memory region, as we pass separate pointers to them to
	// the device, but this design just makes things easier to implement.
	//
	// One added benefit of allocating the memory manually is that we have full
	// control over its lifetime and don't risk the garbage collector collecting
	// our valuable structures while the device still works with them.

	// The descriptor table is at the start of the page, so alignment is not an
	// issue here.
	descriptorTableStart := 0
	descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
	availableRingStart := align(descriptorTableEnd, availableRingAlignment)
	availableRingEnd := availableRingStart + availableRingSize(queueSize)
	usedRingStart := align(availableRingEnd, usedRingAlignment)
	usedRingEnd := usedRingStart + usedRingSize(queueSize)

	sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
		unix.PROT_READ|unix.PROT_WRITE,
		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
	if err != nil {
		return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
	}

	sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd])
	sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
	sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])

	sq.kickEventFD, err = eventfd.Create()
	if err != nil {
		return nil, fmt.Errorf("create kick event file descriptor: %w", err)
	}
	sq.callEventFD, err = eventfd.Create()
	if err != nil {
		return nil, fmt.Errorf("create call event file descriptor: %w", err)
	}

	if err = sq.descriptorTable.initializeDescriptors(); err != nil {
		return nil, fmt.Errorf("initialize descriptors: %w", err)
	}

	// Initialize channels.
	sq.usedChains = make(chan UsedElement, queueSize)
	sq.moreFreeDescriptors = make(chan struct{})

	// Consume used buffer notifications in the background.
	sq.stop = sq.startConsumeUsedRing()

	return &sq, nil
}
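
// A minimal usage sketch (hypothetical, not part of this change; a real caller
// would also register the ring addresses and event file descriptors with the
// device, e.g. via the vhost ioctls):
//
//	sq, err := virtqueue.NewSplitQueue(256)
//	if err != nil {
//		return fmt.Errorf("new split queue: %w", err)
//	}
//	defer sq.Close()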

// Size returns the size of this queue, which is the number of entries/buffers
// this queue can hold.
func (sq *SplitQueue) Size() int {
	sq.ensureInitialized()
	return sq.size
}

// DescriptorTable returns the [DescriptorTable] behind this queue.
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
	sq.ensureInitialized()
	return sq.descriptorTable
}

// AvailableRing returns the [AvailableRing] behind this queue.
func (sq *SplitQueue) AvailableRing() *AvailableRing {
	sq.ensureInitialized()
	return sq.availableRing
}

// UsedRing returns the [UsedRing] behind this queue.
func (sq *SplitQueue) UsedRing() *UsedRing {
	sq.ensureInitialized()
	return sq.usedRing
}

// KickEventFD returns the kick event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) KickEventFD() int {
	sq.ensureInitialized()
	return sq.kickEventFD.FD()
}

// CallEventFD returns the call event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) CallEventFD() int {
	sq.ensureInitialized()
	return sq.callEventFD.FD()
}

// UsedDescriptorChains returns the channel that receives [UsedElement]s for all
// descriptor chains that were used by the device.
//
// Users of the [SplitQueue] should read from this channel, handle the used
// descriptor chains and free them using [SplitQueue.FreeDescriptorChain] when
// they're done with them. When this does not happen, the queue will run full
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
//
// When [SplitQueue.Close] is called, this channel will be closed as well.
func (sq *SplitQueue) UsedDescriptorChains() chan UsedElement {
	sq.ensureInitialized()
	return sq.usedChains
}

// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
// A function is returned that can be used to gracefully cancel it.
func (sq *SplitQueue) startConsumeUsedRing() func() error {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error)
	go func() {
		done <- sq.consumeUsedRing(ctx)
	}()
	return func() error {
		cancel()

		// The goroutine blocks until it receives a signal on the event file
		// descriptor, so it will never notice the context being canceled.
		// To resolve this, we can just produce a fake signal ourselves to wake
		// it up.
		if err := sq.callEventFD.Notify(); err != nil {
			return fmt.Errorf("wake up goroutine: %w", err)
		}

		// Wait for the goroutine to end. This prevents the event file
		// descriptor from being closed while it's still being used.
		// If the goroutine failed, this is the last chance to propagate the
		// error so it at least doesn't go unnoticed, even though the error may
		// be stale by now.
		if err := <-done; err != nil {
			return fmt.Errorf("goroutine: consume used ring: %w", err)
		}
		return nil
	}
}

// consumeUsedRing runs in a goroutine, waits for the device to signal that it
// has used descriptor chains and puts all new [UsedElement]s into the channel
// for them.
func (sq *SplitQueue) consumeUsedRing(ctx context.Context) error {
	for ctx.Err() == nil {
		// Wait for a signal from the device.
		if err := sq.callEventFD.Wait(); err != nil {
			return fmt.Errorf("wait: %w", err)
		}

		// Process all new used elements.
		for _, usedElement := range sq.usedRing.take() {
			sq.usedChains <- usedElement
		}
	}

	return nil
}

// OfferDescriptorChain offers a descriptor chain to the device that contains a
// number of device-readable buffers (out buffers) and device-writable buffers
// (in buffers).
//
// All buffers in the outBuffers slice will be concatenated by chaining
// descriptors, one for each buffer in the slice. When a buffer is too large to
// fit into a single descriptor (limited by the system's page size), it will be
// split up into multiple descriptors within the chain.
// When numInBuffers is greater than zero, the given number of device-writable
// descriptors will be appended to the end of the chain, each referencing a
// whole memory page (see [os.Getpagesize]).
//
// When the queue is full and no more descriptor chains can be added, a wrapped
// [ErrNotEnoughFreeDescriptors] will be returned. If you set waitFree to true,
// this method will handle this error and will block instead until there are
// enough free descriptors again.
//
// After defining the descriptor chain in the [DescriptorTable], the index of
// the head of the chain will be made available to the device using the
// [AvailableRing] and will be returned by this method.
// Callers should read from the [SplitQueue.UsedDescriptorChains] channel to be
// notified when the descriptor chain was used by the device and should free the
// used descriptor chains again using [SplitQueue.FreeDescriptorChain] when
// they're done with them. When this does not happen, the queue will run full
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int, waitFree bool) (uint16, error) {
	sq.ensureInitialized()

	// Each descriptor can only hold a whole memory page, so split large out
	// buffers into multiple smaller ones.
	outBuffers = splitBuffers(outBuffers, os.Getpagesize())

	// Synchronize the offering of descriptor chains. While the descriptor table
	// and available ring are synchronized on their own as well, this does not
	// protect us from interleaved calls which could cause reordering.
	// By locking here, we can ensure that all descriptor chains are made
	// available to the device in the same order as this method was called.
	sq.offerMutex.Lock()
	defer sq.offerMutex.Unlock()

	// Create a descriptor chain for the given buffers.
	var (
		head uint16
		err  error
	)
	for {
		head, err = sq.descriptorTable.createDescriptorChain(outBuffers, numInBuffers)
		if err == nil {
			break
		}
		if waitFree && errors.Is(err, ErrNotEnoughFreeDescriptors) {
			// Wait for more free descriptors to be put back into the queue.
			// If the number of free descriptors is still not sufficient, we'll
			// land here again.
			<-sq.moreFreeDescriptors
			continue
		}
		return 0, fmt.Errorf("create descriptor chain: %w", err)
	}

	// Make the descriptor chain available to the device.
	sq.availableRing.offer([]uint16{head})

	// Notify the device to make it process the updated available ring.
	if err := sq.kickEventFD.Notify(); err != nil {
		return head, fmt.Errorf("notify device: %w", err)
	}

	return head, nil
}
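
// A hedged sketch of the intended call pattern (packet is a placeholder and
// error handling is shortened):
//
//	if _, err := sq.OfferDescriptorChain([][]byte{packet}, 0, true); err != nil {
//		return fmt.Errorf("offer: %w", err)
//	}
//	// Chains may complete in any order; a real caller would match
//	// elem.DescriptorIndex against its outstanding heads.
//	elem := <-sq.UsedDescriptorChains()
//	if err := sq.FreeDescriptorChain(uint16(elem.DescriptorIndex)); err != nil {
//		return fmt.Errorf("free: %w", err)
//	}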

// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
	sq.ensureInitialized()
	return sq.descriptorTable.getDescriptorChain(head)
}

// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// This creates new room in the queue which can be used by following
// [SplitQueue.OfferDescriptorChain] calls.
// When there are outstanding calls to [SplitQueue.OfferDescriptorChain] that
// are waiting for free room in the queue, they may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
	sq.ensureInitialized()

	if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
		return fmt.Errorf("free: %w", err)
	}

	// There is more free room in the descriptor table now.
	// This is a fire-and-forget signal, so do not block when nobody listens.
	select {
	case sq.moreFreeDescriptors <- struct{}{}:
	default:
	}

	return nil
}

// Close releases all resources used for this queue.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (sq *SplitQueue) Close() error {
	var errs []error

	if sq.stop != nil {
		// This has to happen before the event file descriptors may be closed.
		if err := sq.stop(); err != nil {
			errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
		}

		// The stop function blocked until the goroutine ended, so the channel
		// can now safely be closed.
		close(sq.usedChains)

		// Make sure that this code block is executed only once.
		sq.stop = nil
	}

	if err := sq.kickEventFD.Close(); err != nil {
		errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
	}
	if err := sq.callEventFD.Close(); err != nil {
		errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
	}

	// The descriptor table may not have been set up yet when cleaning up a
	// partially initialized queue.
	if sq.descriptorTable != nil {
		if err := sq.descriptorTable.releaseBuffers(); err != nil {
			errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
		}
	}

	if sq.buf != nil {
		if err := unix.Munmap(sq.buf); err == nil {
			sq.buf = nil
		} else {
			errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
		}
	}

	return errors.Join(errs...)
}

// ensureInitialized is used as a guard to prevent methods from being called on
// an uninitialized instance.
func (sq *SplitQueue) ensureInitialized() {
	if sq.buf == nil {
		panic("split queue is not initialized")
	}
}

// align rounds index up to the next multiple of alignment.
func align(index, alignment int) int {
	remainder := index % alignment
	if remainder == 0 {
		return index
	}
	return index + alignment - remainder
}
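
// Worked example for align (illustrative values): align(6, 4) returns 8,
// because 6 % 4 == 2 and 6 + (4 - 2) == 8; align(8, 4) returns 8 unchanged.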

// splitBuffers processes a list of buffers and splits each buffer that is
// larger than the size limit into multiple smaller buffers.
func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
	result := make([][]byte, 0, len(buffers))
	for _, buffer := range buffers {
		for added := 0; added < len(buffer); added += sizeLimit {
			if len(buffer)-added <= sizeLimit {
				result = append(result, buffer[added:])
				break
			}
			result = append(result, buffer[added:added+sizeLimit])
		}
	}

	return result
}
105 overlay/virtqueue/split_virtqueue_internal_test.go Normal file
@@ -0,0 +1,105 @@
package virtqueue

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestSplitQueue_MemoryAlignment(t *testing.T) {
	tests := []struct {
		name      string
		queueSize int
	}{
		{
			name:      "minimal queue size",
			queueSize: 1,
		},
		{
			name:      "small queue size",
			queueSize: 8,
		},
		{
			name:      "large queue size",
			queueSize: 256,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sq, err := NewSplitQueue(tt.queueSize)
			require.NoError(t, err)
			// Close the queue again to release its memory and stop the
			// background goroutine.
			t.Cleanup(func() {
				assert.NoError(t, sq.Close())
			})

			assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment)
			assert.Zero(t, sq.availableRing.Address()%availableRingAlignment)
			assert.Zero(t, sq.usedRing.Address()%usedRingAlignment)
		})
	}
}

func TestSplitBuffers(t *testing.T) {
	const sizeLimit = 16
	tests := []struct {
		name     string
		buffers  [][]byte
		expected [][]byte
	}{
		{
			name:     "no buffers",
			buffers:  make([][]byte, 0),
			expected: make([][]byte, 0),
		},
		{
			name: "small",
			buffers: [][]byte{
				make([]byte, 11),
			},
			expected: [][]byte{
				make([]byte, 11),
			},
		},
		{
			name: "exact size",
			buffers: [][]byte{
				make([]byte, sizeLimit),
			},
			expected: [][]byte{
				make([]byte, sizeLimit),
			},
		},
		{
			name: "large",
			buffers: [][]byte{
				make([]byte, 42),
			},
			expected: [][]byte{
				make([]byte, 16),
				make([]byte, 16),
				make([]byte, 10),
			},
		},
		{
			name: "mixed",
			buffers: [][]byte{
				make([]byte, 7),
				make([]byte, 30),
				make([]byte, 15),
				make([]byte, 32),
			},
			expected: [][]byte{
				make([]byte, 7),
				make([]byte, 16),
				make([]byte, 14),
				make([]byte, 15),
				make([]byte, 16),
				make([]byte, 16),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			actual := splitBuffers(tt.buffers, sizeLimit)
			assert.Equal(t, tt.expected, actual)
		})
	}
}
17 overlay/virtqueue/used_element.go Normal file
@@ -0,0 +1,17 @@
package virtqueue

// usedElementSize is the number of bytes needed to store a [UsedElement] in
// memory.
const usedElementSize = 8

// UsedElement is an element of the [UsedRing] and describes a descriptor chain
// that was used by the device.
type UsedElement struct {
	// DescriptorIndex is the index of the head of the used descriptor chain in
	// the [DescriptorTable].
	// The index is 32-bit here for padding reasons.
	DescriptorIndex uint32
	// Length is the number of bytes written into the device-writable portion
	// of the buffer described by the descriptor chain.
	Length uint32
}
12 overlay/virtqueue/used_element_internal_test.go Normal file
@@ -0,0 +1,12 @@
package virtqueue

import (
	"testing"
	"unsafe"

	"github.com/stretchr/testify/assert"
)

func TestUsedElement_Size(t *testing.T) {
	assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
}
119 overlay/virtqueue/used_ring.go Normal file
@@ -0,0 +1,119 @@
package virtqueue

import (
	"fmt"
	"sync"
	"unsafe"
)

// usedRingFlag is a flag that describes a [UsedRing].
type usedRingFlag uint16

const (
	// usedRingFlagNoNotify is used by the host to advise the guest to not
	// kick it when adding a buffer. It's unreliable, so it's simply an
	// optimization. The guest will still kick when it's out of buffers.
	usedRingFlagNoNotify usedRingFlag = 1 << iota
)

// usedRingSize is the number of bytes needed to store a [UsedRing] with the
// given queue size in memory.
func usedRingSize(queueSize int) int {
	return 6 + usedElementSize*queueSize
}

// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
// required by the virtio spec.
const usedRingAlignment = 4

// UsedRing is where the device returns descriptor chains once it is done with
// them. Each ring entry is a [UsedElement]. It is only written to by the device
// and read by the driver.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type UsedRing struct {
	initialized bool

	// flags that describe this ring.
	flags *usedRingFlag
	// ringIndex indicates where the device would put the next entry into the
	// ring (modulo the queue size).
	ringIndex *uint16
	// ring contains the [UsedElement]s. It wraps around at queue size.
	ring []UsedElement
	// availableEvent is not used by this implementation, but we reserve it
	// anyway to avoid issues in case a device may try to write to it, contrary
	// to the virtio specification.
	availableEvent *uint16

	// lastIndex is the internal ringIndex up to which all [UsedElement]s were
	// processed.
	lastIndex uint16

	mu sync.Mutex
}

// newUsedRing creates a used ring that uses the given underlying memory. The
// length of the memory slice must match the size needed for the ring (see
// [usedRingSize]) for the given queue size.
func newUsedRing(queueSize int, mem []byte) *UsedRing {
	ringSize := usedRingSize(queueSize)
	if len(mem) != ringSize {
		panic(fmt.Sprintf("memory size (%v) does not match required size "+
			"for used ring: %v", len(mem), ringSize))
	}

	r := UsedRing{
		initialized:    true,
		flags:          (*usedRingFlag)(unsafe.Pointer(&mem[0])),
		ringIndex:      (*uint16)(unsafe.Pointer(&mem[2])),
		ring:           unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
		availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
	}
	r.lastIndex = *r.ringIndex
	return &r
}

// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *UsedRing) Address() uintptr {
	if !r.initialized {
		panic("used ring is not initialized")
	}
	return uintptr(unsafe.Pointer(r.flags))
}

// take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method.
func (r *UsedRing) take() []UsedElement {
	r.mu.Lock()
	defer r.mu.Unlock()

	ringIndex := *r.ringIndex
	if ringIndex == r.lastIndex {
		// Nothing new.
		return nil
	}

	// Calculate the number of new used elements that we can read from the
	// ring. The ring index may wrap, but the unsigned 16-bit subtraction
	// handles that case correctly.
	count := int(ringIndex - r.lastIndex)

	// The number of new elements can never exceed the queue size.
	if count > len(r.ring) {
		panic("used ring contains more new elements than the ring is long")
	}

	elems := make([]UsedElement, count)
	for i := range count {
		elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
		r.lastIndex++
	}

	return elems
}
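
// Worked example for the wraparound in take (mirroring the "index overflow"
// test case below): if the device advanced the ring index to 2 after wrapping
// and lastIndex is still 65535, then uint16(2-65535) == 3, so the three
// elements at ring positions 65535 % queueSize, 0 and 1 are returned.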
136 overlay/virtqueue/used_ring_internal_test.go Normal file
@@ -0,0 +1,136 @@
package virtqueue

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestUsedRing_MemoryLayout(t *testing.T) {
	const queueSize = 2

	memory := make([]byte, usedRingSize(queueSize))
	r := newUsedRing(queueSize, memory)

	*r.flags = 0x01ff
	*r.ringIndex = 1
	r.ring[0] = UsedElement{
		DescriptorIndex: 0x0123,
		Length:          0x4567,
	}
	r.ring[1] = UsedElement{
		DescriptorIndex: 0x89ab,
		Length:          0xcdef,
	}

	assert.Equal(t, []byte{
		0xff, 0x01,
		0x01, 0x00,
		0x23, 0x01, 0x00, 0x00,
		0x67, 0x45, 0x00, 0x00,
		0xab, 0x89, 0x00, 0x00,
		0xef, 0xcd, 0x00, 0x00,
		0x00, 0x00,
	}, memory)
}

func TestUsedRing_Take(t *testing.T) {
	const queueSize = 8

	tests := []struct {
		name      string
		ring      []UsedElement
		ringIndex uint16
		lastIndex uint16
		expected  []UsedElement
	}{
		{
			name: "nothing new",
			ring: []UsedElement{
				{DescriptorIndex: 1},
				{DescriptorIndex: 2},
				{DescriptorIndex: 3},
				{DescriptorIndex: 4},
				{},
				{},
				{},
				{},
			},
			ringIndex: 4,
			lastIndex: 4,
			expected:  nil,
		},
		{
			name: "no overflow",
			ring: []UsedElement{
				{DescriptorIndex: 1},
				{DescriptorIndex: 2},
				{DescriptorIndex: 3},
				{DescriptorIndex: 4},
				{},
				{},
				{},
				{},
			},
			ringIndex: 4,
			lastIndex: 1,
			expected: []UsedElement{
				{DescriptorIndex: 2},
				{DescriptorIndex: 3},
				{DescriptorIndex: 4},
			},
		},
		{
			name: "ring overflow",
			ring: []UsedElement{
				{DescriptorIndex: 9},
				{DescriptorIndex: 10},
				{DescriptorIndex: 3},
				{DescriptorIndex: 4},
				{DescriptorIndex: 5},
				{DescriptorIndex: 6},
				{DescriptorIndex: 7},
				{DescriptorIndex: 8},
			},
			ringIndex: 10,
			lastIndex: 7,
			expected: []UsedElement{
				{DescriptorIndex: 8},
				{DescriptorIndex: 9},
				{DescriptorIndex: 10},
			},
		},
		{
			name: "index overflow",
			ring: []UsedElement{
				{DescriptorIndex: 9},
				{DescriptorIndex: 10},
				{DescriptorIndex: 3},
				{DescriptorIndex: 4},
				{DescriptorIndex: 5},
				{DescriptorIndex: 6},
				{DescriptorIndex: 7},
				{DescriptorIndex: 8},
			},
			ringIndex: 2,
			lastIndex: 65535,
			expected: []UsedElement{
				{DescriptorIndex: 8},
				{DescriptorIndex: 9},
				{DescriptorIndex: 10},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			memory := make([]byte, usedRingSize(queueSize))
			r := newUsedRing(queueSize, memory)

			copy(r.ring, tt.ring)
			*r.ringIndex = tt.ringIndex
			r.lastIndex = tt.lastIndex

			assert.Equal(t, tt.expected, r.take())
		})
	}
}