This commit is contained in:
JackDoan
2025-11-07 16:50:43 -06:00
parent 6e22bfeeb1
commit e3be0943fd
13 changed files with 1469 additions and 8 deletions

412
overlay/vhostnet/device.go Normal file
View File

@@ -0,0 +1,412 @@
package vhostnet
import (
"errors"
"fmt"
"os"
"runtime"
"github.com/hetznercloud/virtio-go/vhost"
"github.com/hetznercloud/virtio-go/virtio"
"github.com/hetznercloud/virtio-go/virtqueue"
"golang.org/x/sys/unix"
)
// ErrDeviceClosed is returned when the [Device] is closed while operations are
// still running.
var ErrDeviceClosed = errors.New("device was closed")
// The indexes for the receive and transmit queues.
const (
receiveQueueIndex = 0
transmitQueueIndex = 1
)
// Device represents a vhost networking device within the kernel-level virtio
// implementation and provides methods to interact with it.
type Device struct {
initialized bool
controlFD int
receiveQueue *virtqueue.SplitQueue
transmitQueue *virtqueue.SplitQueue
// transmitted contains channels for each possible descriptor chain head
// index. This is used for packet transmit notifications.
// When a packet was transmitted and the descriptor chain was used by the
// device, the corresponding channel receives the [virtqueue.UsedElement]
// instance provided by the device.
transmitted []chan virtqueue.UsedElement
}
// NewDevice initializes a new vhost networking device within the
// kernel-level virtio implementation, sets up the virtqueues and returns a
// [Device] instance that can be used to communicate with that vhost device.
//
// There are multiple options that can be passed to this constructor to
// influence device creation:
// - [WithQueueSize]
// - [WithBackendFD]
// - [WithBackendDevice]
//
// Remember to call [Device.Close] after use to free up resources.
func NewDevice(options ...Option) (_ *Device, err error) {
opts := optionDefaults
opts.apply(options)
if err = opts.validate(); err != nil {
return nil, fmt.Errorf("invalid options: %w", err)
}
dev := Device{
controlFD: -1,
}
// Clean up a partially initialized device when something fails.
defer func() {
if err != nil {
_ = dev.Close()
}
}()
// Retrieve a new control file descriptor. This will be used to configure
// the vhost networking device in the kernel.
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
if err != nil {
return nil, fmt.Errorf("get control file descriptor: %w", err)
}
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
return nil, fmt.Errorf("own control file descriptor: %w", err)
}
// Advertise the supported features. This isn't much for now.
// TODO: Add feature options and implement proper feature negotiation.
features := virtio.FeatureVersion1 // | virtio.FeatureNetMergeRXBuffers
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
return nil, fmt.Errorf("set features: %w", err)
}
// Initialize and register the queues needed for the networking device.
if dev.receiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
if dev.transmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
// Set up memory mappings for all buffers used by the queues. This has to
// happen before a backend for the queues can be registered.
memoryLayout := vhost.NewMemoryLayoutForQueues(
[]*virtqueue.SplitQueue{dev.receiveQueue, dev.transmitQueue},
)
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
return nil, fmt.Errorf("setup memory layout: %w", err)
}
// Set the queue backends. This activates the queues within the kernel.
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set receive queue backend: %w", err)
}
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set transmit queue backend: %w", err)
}
// Fully populate the receive queue with available buffers which the device
// can write new packets into.
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
// Initialize channels for transmit notifications.
dev.transmitted = make([]chan virtqueue.UsedElement, dev.transmitQueue.Size())
for i := range len(dev.transmitted) {
// It is important to use a single-element buffered channel here.
// When the channel was unbuffered and the monitorTransmitQueue
// goroutine would write into it, the writing would block which could
// lead to deadlocks in case transmit notifications do not arrive in
// order.
// When the goroutine would use fire-and-forget to write into that
// channel, there may be a chance that the TransmitPacket does not
// receive the transmit notification due to this being a race condition.
// Buffering a single transmit notification resolves this without race
// conditions or possible deadlocks.
dev.transmitted[i] = make(chan virtqueue.UsedElement, 1)
}
// Monitor transmit queue in background.
go dev.monitorTransmitQueue()
dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without
// Close being called first.
devPtr := &dev
runtime.SetFinalizer(devPtr, (*Device).Close)
return devPtr, nil
}
// monitorTransmitQueue waits for the device to advertise used descriptor chains
// in the transmit queue and produces a transmit notification via the
// corresponding channel.
func (dev *Device) monitorTransmitQueue() {
usedChan := dev.transmitQueue.UsedDescriptorChains()
for {
used, ok := <-usedChan
if !ok {
// The queue was closed.
return
}
if int(used.DescriptorIndex) > len(dev.transmitted) {
panic(fmt.Sprintf("device provided a used descriptor index (%d) that is out of range",
used.DescriptorIndex))
}
dev.transmitted[used.DescriptorIndex] <- used
}
}
// TransmitPacket writes the given packet into the transmit queue of this
// device. The packet will be prepended with the [virtio.NetHdr].
//
// When the queue is full, this will block until the queue has enough room to
// transmit the packet. This method will not return before the packet was
// transmitted and the device notifies that it has used the packet buffer.
func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
// Prepend the packet with its virtio-net header.
vnethdrBuf := make([]byte, virtio.NetHdrSize) //todo WHY
if err := vnethdr.Encode(vnethdrBuf); err != nil {
return fmt.Errorf("encode vnethdr: %w", err)
}
outBuffers := [][]byte{vnethdrBuf, packet}
chainIndex, err := dev.transmitQueue.OfferDescriptorChain(outBuffers, 0, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
// Wait for the packet to have been transmitted.
<-dev.transmitted[chainIndex]
if err = dev.transmitQueue.FreeDescriptorChain(chainIndex); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
return nil
}
// ReceivePacket reads the next available packet from the receive queue of this
// device and returns its [virtio.NetHdr] and packet data separately.
//
// When no packet is available, this will block until there is one.
//
// When this method returns an error, the receive queue will likely be in a
// broken state which this implementation cannot recover from. The caller should
// close the device and not attempt any additional receives.
func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
var (
chainHeads []uint16
vnethdr virtio.NetHdr
buffers [][]byte
// Each packet starts with a virtio-net header which we have to subtract
// from the total length.
packetLength = -virtio.NetHdrSize
)
// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
// made of multiple descriptor chains which are to be merged.
for remainingChains := 1; remainingChains > 0; remainingChains-- {
// Get the next descriptor chain.
usedElement, ok := <-dev.receiveQueue.UsedDescriptorChains()
if !ok {
return virtio.NetHdr{}, nil, ErrDeviceClosed
}
// Track this chain to be freed later.
head := uint16(usedElement.DescriptorIndex)
chainHeads = append(chainHeads, head)
outBuffers, inBuffers, err := dev.receiveQueue.GetDescriptorChain(head)
if err != nil {
// When this fails we may miss to free some descriptor chains. We
// could try to mitigate this by deferring the freeing somehow, but
// it's not worth the hassle. When this method fails, the queue will
// be in a broken state anyway.
return virtio.NetHdr{}, nil, fmt.Errorf("get descriptor chain: %w", err)
}
if len(outBuffers) > 0 {
// How did this happen!?
panic("receive queue contains device-readable buffers")
}
if len(inBuffers) == 0 {
// Empty descriptor chains should not be possible.
panic("descriptor chain contains no buffers")
}
// The device tells us how many bytes of the descriptor chain it has
// actually written to. The specification forces the device to fully
// fill up all but the last descriptor chain when multiple descriptor
// chains are being merged, but being more compatible here doesn't hurt.
inBuffers = truncateBuffers(inBuffers, int(usedElement.Length))
packetLength += int(usedElement.Length)
// Is this the first descriptor chain we process?
if len(buffers) == 0 {
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = vnethdr.Decode(inBuffers[0]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return virtio.NetHdr{}, nil, fmt.Errorf("decode vnethdr: %w", err)
}
inBuffers[0] = inBuffers[0][virtio.NetHdrSize:]
// The virtio-net header tells us how many descriptor chains this
// packet is long.
remainingChains = int(vnethdr.NumBuffers)
}
buffers = append(buffers, inBuffers...)
}
// Copy all the buffers together to produce the complete packet slice.
packet := make([]byte, packetLength)
copied := 0
for _, buffer := range buffers {
copied += copy(packet[copied:], buffer)
}
if copied != packetLength {
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
}
// Now that we have copied all buffers, we can free the used descriptor
// chains again.
// TODO: Recycling the descriptor chains would be more efficient than
// freeing them just to offer them again right after.
for _, head := range chainHeads {
if err := dev.receiveQueue.FreeDescriptorChain(head); err != nil {
return virtio.NetHdr{}, nil, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
}
}
// It's advised to always keep the receive queue fully populated with
// available buffers which the device can write new packets into.
if err := dev.refillReceiveQueue(); err != nil {
return virtio.NetHdr{}, nil, fmt.Errorf("refill receive queue: %w", err)
}
return vnethdr, packet, nil
}
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?
// refillReceiveQueue offers as many new device-writable buffers to the device
// as the queue can fit. The device will then use these to write received
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.receiveQueue.OfferDescriptorChain(nil, 1, false)
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
return nil
}
return fmt.Errorf("offer descriptor chain: %w", err)
}
}
}
// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (dev *Device) Close() error {
dev.initialized = false
// Closing the control file descriptor will unregister all queues from the
// kernel.
if dev.controlFD >= 0 {
if err := unix.Close(dev.controlFD); err != nil {
// Return an error and do not continue, because the memory used for
// the queues should not be released before they were unregistered
// from the kernel.
return fmt.Errorf("close control file descriptor: %w", err)
}
dev.controlFD = -1
}
var errs []error
if dev.receiveQueue != nil {
if err := dev.receiveQueue.Close(); err == nil {
dev.receiveQueue = nil
} else {
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
}
}
if dev.transmitQueue != nil {
if err := dev.transmitQueue.Close(); err == nil {
dev.transmitQueue = nil
} else {
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
}
}
if len(errs) == 0 {
// Everything was cleaned up. No need to run the finalizer anymore.
runtime.SetFinalizer(dev, nil)
}
return errors.Join(errs...)
}
// ensureInitialized is used as a guard to prevent methods to be called on an
// uninitialized instance.
func (dev *Device) ensureInitialized() {
if !dev.initialized {
panic("device is not initialized")
}
}
// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int) (*virtqueue.SplitQueue, error) {
var (
queue *virtqueue.SplitQueue
err error
)
if queue, err = virtqueue.NewSplitQueue(queueSize); err != nil {
return nil, fmt.Errorf("create virtqueue: %w", err)
}
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
}
return queue, nil
}
// truncateBuffers returns a new list of buffers whose combined length matches
// exactly the specified length. When the specified length exceeds the length of
// the buffers, this is an error. When it is smaller, the buffer list will be
// truncated accordingly.
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
for _, buffer := range buffers {
if length < len(buffer) {
out = append(out, buffer[:length])
return
}
out = append(out, buffer)
length -= len(buffer)
}
if length > 0 {
panic("length exceeds the combined length of all buffers")
}
return
}

View File

@@ -0,0 +1,86 @@
package vhostnet
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTruncateBuffers(t *testing.T) {
tests := []struct {
name string
buffers [][]byte
length int
expected [][]byte
}{
{
name: "no buffers",
buffers: nil,
length: 0,
expected: nil,
},
{
name: "single buffer correct length",
buffers: [][]byte{
make([]byte, 100),
},
length: 100,
expected: [][]byte{
make([]byte, 100),
},
},
{
name: "single buffer truncated",
buffers: [][]byte{
make([]byte, 100),
},
length: 90,
expected: [][]byte{
make([]byte, 90),
},
},
{
name: "multiple buffers correct length",
buffers: [][]byte{
make([]byte, 200),
make([]byte, 100),
},
length: 300,
expected: [][]byte{
make([]byte, 200),
make([]byte, 100),
},
},
{
name: "multiple buffers truncated",
buffers: [][]byte{
make([]byte, 200),
make([]byte, 100),
},
length: 250,
expected: [][]byte{
make([]byte, 200),
make([]byte, 50),
},
},
{
name: "multiple buffers truncated buffer list",
buffers: [][]byte{
make([]byte, 200),
make([]byte, 200),
make([]byte, 200),
},
length: 350,
expected: [][]byte{
make([]byte, 200),
make([]byte, 150),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := truncateBuffers(tt.buffers, tt.length)
assert.Equal(t, tt.expected, actual)
})
}
}

View File

@@ -0,0 +1,224 @@
package vhostnet_test
import (
"fmt"
"os"
"sync"
"testing"
"github.com/gopacket/gopacket/afpacket"
"github.com/hetznercloud/virtio-go/internal/testsupport"
"github.com/hetznercloud/virtio-go/tuntap"
"github.com/hetznercloud/virtio-go/vhostnet"
"github.com/hetznercloud/virtio-go/virtio"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// Here is the general idea of how the following tests work to verify the
// correct communication with the vhost-net device within the kernel:
//
// +-----------------------------------+
// | go test running in user space |
// +-----------------------------------+
// ^ ^
// | |
// capture / write transmit / receive
// using AF_PACKET using this package
// | |
// v v
// +----------------+ +-----------+
// | tun (TAP mode) |<---->| vhost-net |
// +----------------+ +-----------+
//
func TestDevice_TransmitPacket(t *testing.T) {
testsupport.VirtrunOnly(t)
fx := NewTestFixture(t)
for _, length := range []int{64, 1514, 9014, 64100} {
t.Run(fmt.Sprintf("%d byte packet", length), func(t *testing.T) {
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), length)
// Transmit the packet over the vhost-net device.
require.NoError(t, fx.NetDevice.TransmitPacket(vnethdr, pkt))
// Check if the packet arrived at the TAP device. The virtio-net
// header should have been stripped by the TAP device.
data, _, err := fx.TPacket.ReadPacketData()
assert.NoError(t, err)
assert.Equal(t, pkt, data)
})
}
}
func TestDevice_ReceivePacket(t *testing.T) {
testsupport.VirtrunOnly(t)
fx := NewTestFixture(t)
for _, length := range []int{64, 1514, 9014, 64100} {
t.Run(fmt.Sprintf("%d byte packet", length), func(t *testing.T) {
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), length)
prependedPkt := testsupport.PrependPacket(t, vnethdr, pkt)
// Write the prepended packet to the TAP device.
require.NoError(t, fx.TPacket.WritePacketData(prependedPkt))
// Try to receive the packet on the vhost-net device.
vnethdr, data, err := fx.NetDevice.ReceivePacket()
assert.NoError(t, err)
assert.Equal(t, pkt, data)
// Large packets should have been received as multiple buffers.
assert.Equal(t, (len(prependedPkt)/os.Getpagesize())+1, int(vnethdr.NumBuffers))
})
}
}
func TestDevice_TransmitManyPackets(t *testing.T) {
testsupport.VirtrunOnly(t)
fx := NewTestFixture(t)
// Test with a packet which does not fit into a single memory page.
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), 9014)
const count = 1024
var received int
var wg sync.WaitGroup
wg.Go(func() {
for range count {
err := fx.NetDevice.TransmitPacket(vnethdr, pkt)
if !assert.NoError(t, err) {
return
}
}
})
wg.Go(func() {
for range count {
data, _, err := fx.TPacket.ReadPacketData()
if !assert.NoError(t, err) {
return
}
assert.Equal(t, pkt, data)
received++
}
})
wg.Wait()
assert.Equal(t, count, received)
}
func TestDevice_ReceiveManyPackets(t *testing.T) {
testsupport.VirtrunOnly(t)
fx := NewTestFixture(t)
// Test with a packet which does not fit into a single memory page.
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), 9014)
prependedPkt := testsupport.PrependPacket(t, vnethdr, pkt)
const count = 1024
var received int
var wg sync.WaitGroup
wg.Go(func() {
for range count {
err := fx.TPacket.WritePacketData(prependedPkt)
if !assert.NoError(t, err) {
return
}
}
})
wg.Go(func() {
for range count {
_, data, err := fx.NetDevice.ReceivePacket()
if !assert.NoError(t, err) {
return
}
assert.Equal(t, pkt, data)
received++
}
})
wg.Wait()
assert.Equal(t, count, received)
}
type TestFixture struct {
TAPDevice *tuntap.Device
NetDevice *vhostnet.Device
TPacket *afpacket.TPacket
}
func NewTestFixture(t *testing.T) *TestFixture {
testsupport.VirtrunOnly(t)
// In case something doesn't work, some more debug logging from the kernel
// modules may be very helpful.
testsupport.EnableDynamicDebug(t, "module tun")
testsupport.EnableDynamicDebug(t, "module vhost")
testsupport.EnableDynamicDebug(t, "module vhost_net")
// Make sure the Linux kernel does not send router solicitations that may
// interfere with these tests.
testsupport.SetSysctl(t, "net.ipv6.conf.all.disable_ipv6", "1")
var (
fx TestFixture
err error
)
// Create a TAP device.
fx.TAPDevice, err = tuntap.NewDevice(
tuntap.WithDeviceType(tuntap.DeviceTypeTAP),
// Helps to stop the Linux kernel from sending packets on this
// interface.
tuntap.WithInterfaceFlags(unix.IFF_NOARP),
// Packets going over this device are prepended with a virtio-net
// header. When this is not set, then packets written to the TAP device
// will be passed to the Linux network stack without their virtio-net
// header stripped.
tuntap.WithVirtioNetHdr(true),
// When writing packets into the TAP device using the RAW socket, we
// don't want the offloads to be applied by the kernel. Advertising
// offload support makes the kernel pass the offload request along to
// our vhost-net device.
tuntap.WithOffloads(unix.TUN_F_CSUM|unix.TUN_F_USO4|unix.TUN_F_USO6),
)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, fx.TAPDevice.Close())
})
// Create a vhost-net device that uses the TAP device as the backend.
fx.NetDevice, err = vhostnet.NewDevice(
vhostnet.WithQueueSize(32),
vhostnet.WithBackendDevice(fx.TAPDevice),
)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, fx.NetDevice.Close())
})
// Open a RAW socket to capture packets arriving at the TAP device or
// write packets into it.
fx.TPacket, err = afpacket.NewTPacket(
afpacket.SocketRaw,
afpacket.TPacketVersion3,
afpacket.OptInterface(fx.TAPDevice.Name()),
// Tell the kernel that packets written to this socket are prepended
// with a virto-net header. This is used to communicate the use of GSO
// for large packets.
afpacket.OptVNetHdrSize(virtio.NetHdrSize),
)
require.NoError(t, err)
t.Cleanup(fx.TPacket.Close)
return &fx
}

3
overlay/vhostnet/doc.go Normal file
View File

@@ -0,0 +1,3 @@
// Package vhostnet implements methods to initialize vhost networking devices
// within the kernel-level virtio implementation and communicate with them.
package vhostnet

31
overlay/vhostnet/ioctl.go Normal file
View File

@@ -0,0 +1,31 @@
package vhostnet
import (
"fmt"
"unsafe"
"github.com/hetznercloud/virtio-go/vhost"
)
const (
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
// or TAP device.
//
// Request payload: [vhost.QueueFile]
// Kernel name: VHOST_NET_SET_BACKEND
vhostNetIoctlSetBackend = 0x4008af30
)
// SetQueueBackend attaches a virtqueue of the vhost networking device
// described by controlFD to the given backend file descriptor.
// The backend file descriptor can either be a RAW socket or a TAP device. When
// it is -1, the queue will be detached.
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
QueueIndex: queueIndex,
FD: int32(backendFD),
})); err != nil {
return fmt.Errorf("set queue backend file descriptor: %w", err)
}
return nil
}

View File

@@ -0,0 +1,70 @@
package vhostnet
import (
"errors"
"github.com/hetznercloud/virtio-go/tuntap"
"github.com/hetznercloud/virtio-go/virtqueue"
)
type optionValues struct {
queueSize int
backendFD int
}
func (o *optionValues) apply(options []Option) {
for _, option := range options {
option(o)
}
}
func (o *optionValues) validate() error {
if o.queueSize == -1 {
return errors.New("queue size is required")
}
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
return err
}
if o.backendFD == -1 {
return errors.New("backend file descriptor is required")
}
return nil
}
var optionDefaults = optionValues{
// Required.
queueSize: -1,
// Required.
backendFD: -1,
}
// Option can be passed to [NewDevice] to influence device creation.
type Option func(*optionValues)
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
// that are to be created for the device. It specifies the number of
// entries/buffers each queue can hold. This also affects the memory
// consumption.
// This is required and must be an integer from 1 to 32768 that is also a power
// of 2.
func WithQueueSize(queueSize int) Option {
return func(o *optionValues) { o.queueSize = queueSize }
}
// WithBackendFD returns an [Option] that sets the file descriptor of the
// backend that will be used for the queues of the device. The device will write
// and read packets to/from that backend. The file descriptor can either be of a
// RAW socket or TUN/TAP device.
// Either this or [WithBackendDevice] is required.
func WithBackendFD(backendFD int) Option {
return func(o *optionValues) { o.backendFD = backendFD }
}
// WithBackendDevice returns an [Option] that sets the given TAP device as the
// backend that will be used for the queues of the device. The device will
// write and read packets to/from that backend. The TAP device should have been
// created with the [tuntap.WithVirtioNetHdr] option enabled.
// Either this or [WithBackendFD] is required.
func WithBackendDevice(dev *tuntap.Device) Option {
return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
}

View File

@@ -0,0 +1,66 @@
package vhostnet
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestOptionValues_Apply(t *testing.T) {
opts := optionDefaults
opts.apply([]Option{
WithQueueSize(256),
WithBackendFD(99),
})
assert.Equal(t, optionValues{
queueSize: 256,
backendFD: 99,
}, opts)
}
func TestOptionValues_Validate(t *testing.T) {
tests := []struct {
name string
values optionValues
assertErr assert.ErrorAssertionFunc
}{
{
name: "queue size missing",
values: optionValues{
queueSize: -1,
backendFD: 99,
},
assertErr: assert.Error,
},
{
name: "invalid queue size",
values: optionValues{
queueSize: 24,
backendFD: 99,
},
assertErr: assert.Error,
},
{
name: "backend fd missing",
values: optionValues{
queueSize: 256,
backendFD: -1,
},
assertErr: assert.Error,
},
{
name: "valid",
values: optionValues{
queueSize: 256,
backendFD: 99,
},
assertErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.assertErr(t, tt.values.validate())
})
}
}