mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 08:54:25 +01:00
checkpt
This commit is contained in:
@@ -16,11 +16,11 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/hetznercloud/virtio-go/tuntap"
|
|
||||||
"github.com/hetznercloud/virtio-go/vhostnet"
|
|
||||||
"github.com/hetznercloud/virtio-go/virtio"
|
"github.com/hetznercloud/virtio-go/virtio"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tuntap"
|
||||||
|
"github.com/slackhq/nebula/overlay/vhostnet"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
@@ -121,9 +121,9 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
name := strings.Trim(c.GetString("tun.dev", ""), "\x00")
|
name := strings.Trim(c.GetString("tun.dev", ""), "\x00")
|
||||||
tundev, err := tuntap.NewDevice(
|
tundev, err := tuntap.NewDevice(
|
||||||
tuntap.WithName(name),
|
tuntap.WithName(name),
|
||||||
tuntap.WithDeviceType(tuntap.DeviceTypeTUN),
|
tuntap.WithDeviceType(tuntap.DeviceTypeTUN), //todo wtf
|
||||||
tuntap.WithVirtioNetHdr(true),
|
tuntap.WithVirtioNetHdr(true), //todo hmm
|
||||||
tuntap.WithOffloads(0x0), //todo
|
tuntap.WithOffloads(unix.TUN_F_CSUM|unix.TUN_F_USO4|unix.TUN_F_USO6), //todo
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -137,7 +137,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
t.Device = name
|
t.Device = name
|
||||||
|
|
||||||
vdev, err := vhostnet.NewDevice(
|
vdev, err := vhostnet.NewDevice(
|
||||||
vhostnet.WithBackendDevice(tundev),
|
vhostnet.WithBackendFD(int(tundev.File().Fd())),
|
||||||
vhostnet.WithQueueSize(8), //todo config
|
vhostnet.WithQueueSize(8), //todo config
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -264,9 +264,12 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(p []byte) (int, error) {
|
func (t *tun) Read(p []byte) (int, error) {
|
||||||
_, out, err := t.vdev.ReceivePacket()
|
hdr, out, err := t.vdev.ReceivePacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
}
|
||||||
|
if hdr.NumBuffers == 0 {
|
||||||
|
|
||||||
}
|
}
|
||||||
p = p[:len(out)]
|
p = p[:len(out)]
|
||||||
copy(p, out)
|
copy(p, out)
|
||||||
@@ -278,14 +281,18 @@ func (t *tun) Write(b []byte) (int, error) {
|
|||||||
|
|
||||||
hdr := virtio.NetHdr{ //todo
|
hdr := virtio.NetHdr{ //todo
|
||||||
Flags: 0,
|
Flags: 0,
|
||||||
GSOType: 0,
|
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
|
||||||
HdrLen: 0,
|
HdrLen: 0,
|
||||||
GSOSize: 0,
|
GSOSize: 0,
|
||||||
CsumStart: 0,
|
CsumStart: 0,
|
||||||
CsumOffset: 0,
|
CsumOffset: 0,
|
||||||
NumBuffers: 0,
|
NumBuffers: 0,
|
||||||
}
|
}
|
||||||
|
//todo wow fuck this
|
||||||
|
//bb := make([]byte, maximum+14)
|
||||||
|
//copy(bb[14:], b)
|
||||||
err := t.vdev.TransmitPacket(hdr, b)
|
err := t.vdev.TransmitPacket(hdr, b)
|
||||||
|
//err := t.vdev.TransmitPacket2(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
232
overlay/tuntap/device.go
Normal file
232
overlay/tuntap/device.go
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
package tuntap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/hetznercloud/virtio-go/virtio"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Documentation:
|
||||||
|
// https://docs.kernel.org/networking/tuntap.html
|
||||||
|
// Also worth a read:
|
||||||
|
// https://blog.cloudflare.com/virtual-networking-101-understanding-tap/
|
||||||
|
|
||||||
|
// Device represents a TUN/TAP device.
|
||||||
|
type Device struct {
|
||||||
|
name string
|
||||||
|
ifindex uint32
|
||||||
|
mac net.HardwareAddr
|
||||||
|
file *os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDevice creates a new TUN/TAP device, brings it up, and returns a [Device]
|
||||||
|
// instance providing access to it.
|
||||||
|
//
|
||||||
|
// There are multiple options that can be passed to this constructor to
|
||||||
|
// influence device creation:
|
||||||
|
// - [WithName]
|
||||||
|
// - [WithDeviceType]
|
||||||
|
// - [WithVirtioNetHdr]
|
||||||
|
// - [WithInterfaceFlags]
|
||||||
|
//
|
||||||
|
// Remember to call [Device.Close] after use to free up resources.
|
||||||
|
func NewDevice(options ...Option) (_ *Device, err error) {
|
||||||
|
opts := optionDefaults
|
||||||
|
opts.apply(options)
|
||||||
|
if err = opts.validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a file descriptor. The device will exist as long as we keep this
|
||||||
|
// file descriptor open.
|
||||||
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0666)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("access tuntap driver: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an interface request. When the name is empty, the kernel will
|
||||||
|
// auto-select one.
|
||||||
|
ifreq, err := unix.NewIfreq(opts.name)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, fmt.Errorf("new ifreq: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the new device.
|
||||||
|
ifreq.SetUint16(opts.ifreqFlags())
|
||||||
|
if err = unix.IoctlIfreq(fd, unix.TUNSETIFF, ifreq); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, fmt.Errorf("create device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dev := Device{
|
||||||
|
// The TUNSETIFF ioctl writes the actual name that was chosen for the
|
||||||
|
// device back to the request, so use that.
|
||||||
|
name: ifreq.Name(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the file descriptor of the device non-blocking. This enables us to
|
||||||
|
// cancel reads after a timeout when no packets are arriving.
|
||||||
|
// This, and the call to NewFile has to happen after creating the device:
|
||||||
|
// https://github.com/golang/go/issues/30426#issuecomment-470330742
|
||||||
|
// NewFile will recognize that the file descriptor is non-blocking and will
|
||||||
|
// configure polling for it.
|
||||||
|
if err = unix.SetNonblock(fd, true); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// By wrapping the file descriptor as an os.File, we not only have a
|
||||||
|
// convenient way to read and write, but also register a finalizer that
|
||||||
|
// closes the file descriptor when it's being garbage collected.
|
||||||
|
dev.file = os.NewFile(uintptr(fd), dev.name)
|
||||||
|
|
||||||
|
// Make sure the device is removed when one of the following initialization
|
||||||
|
// steps fails.
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = dev.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if opts.virtioNetHdr {
|
||||||
|
// Tell the device which size we use for our virtio_net_hdr.
|
||||||
|
err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("set vnethdr size: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tell the device which offloads are supported.
|
||||||
|
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, opts.offloads)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("set offloads: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the following ioctls we need just any AF_INET socket, so create one.
|
||||||
|
inet, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open inet socket: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = unix.Close(inet) }()
|
||||||
|
|
||||||
|
// Set the interface flags to bring it up.
|
||||||
|
ifreq.SetUint16(unix.IFF_UP | opts.interfaceFlags)
|
||||||
|
if err = unix.IoctlIfreq(inet, unix.SIOCSIFFLAGS, ifreq); err != nil {
|
||||||
|
return nil, fmt.Errorf("set interface flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the interface index.
|
||||||
|
if err = unix.IoctlIfreq(inet, unix.SIOCGIFINDEX, ifreq); err != nil {
|
||||||
|
return nil, fmt.Errorf("get interface index: %w", err)
|
||||||
|
}
|
||||||
|
dev.ifindex = ifreq.Uint32()
|
||||||
|
|
||||||
|
// Get the MAC address.
|
||||||
|
// This ioctl writes a sockaddr into the data ifru section of the interface
|
||||||
|
// request struct. The MAC address is in the beginning of the
|
||||||
|
// sockaddr.sa_data section.
|
||||||
|
if err = unix.IoctlIfreq(inet, unix.SIOCGIFHWADDR, ifreq); err != nil {
|
||||||
|
return nil, fmt.Errorf("get mac address: %w", err)
|
||||||
|
}
|
||||||
|
dev.mac = unsafe.Slice((*byte)(unsafe.Pointer(ifreq)), 32)[16+2 : 16+8]
|
||||||
|
|
||||||
|
return &dev, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the file descriptor behind this device. This will cause the
|
||||||
|
// TUN/TAP device to be removed.
|
||||||
|
func (dev *Device) Close() error {
|
||||||
|
if err := dev.file.Close(); err != nil {
|
||||||
|
return fmt.Errorf("close file descriptor: %w", err)
|
||||||
|
}
|
||||||
|
dev.file = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the name of this device.
|
||||||
|
func (dev *Device) Name() string {
|
||||||
|
dev.ensureInitialized()
|
||||||
|
return dev.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ifindex returns the interface index of this device.
|
||||||
|
func (dev *Device) Ifindex() uint32 {
|
||||||
|
dev.ensureInitialized()
|
||||||
|
return dev.ifindex
|
||||||
|
}
|
||||||
|
|
||||||
|
// MAC returns the hardware address of this device.
|
||||||
|
func (dev *Device) MAC() net.HardwareAddr {
|
||||||
|
dev.ensureInitialized()
|
||||||
|
return dev.mac
|
||||||
|
}
|
||||||
|
|
||||||
|
// File returns the [os.File] that is used to communicate with this device.
|
||||||
|
// If you access it directly, please be careful to not interfere with this
|
||||||
|
// implementation.
|
||||||
|
func (dev *Device) File() *os.File {
|
||||||
|
dev.ensureInitialized()
|
||||||
|
return dev.file
|
||||||
|
}
|
||||||
|
|
||||||
|
// WritePacket writes the given packet to the TUN/TAP device.
|
||||||
|
// When the [WithVirtioNetHdr] option was enabled, then the caller is
|
||||||
|
// responsible to prepend the packet with a [virtio.NetHdr].
|
||||||
|
func (dev *Device) WritePacket(packet []byte) error {
|
||||||
|
dev.ensureInitialized()
|
||||||
|
|
||||||
|
_, err := dev.file.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write %d bytes: %w", len(packet), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPacket reads the next available packet from the TUN/TAP device into the
|
||||||
|
// given buffer. Make sure that the buffer is large enough, otherwise only a
|
||||||
|
// part of the packet may be read. The number of read bytes will be returned.
|
||||||
|
//
|
||||||
|
// When the [WithVirtioNetHdr] option was enabled, then the read packet will be
|
||||||
|
// prepended with a [virtio.NetHdr]. The caller is responsible to handle it
|
||||||
|
// accordingly.
|
||||||
|
//
|
||||||
|
// A timeout can be given to limit the time this operation blocks. If no packet
|
||||||
|
// arrives within the given timeout, the read is canceled and an error that
|
||||||
|
// wraps [os.ErrDeadlineExceeded] is returned. Pass a timeout of zero to make
|
||||||
|
// this operation block infinitely.
|
||||||
|
func (dev *Device) ReadPacket(buf []byte, timeout time.Duration) (int, error) {
|
||||||
|
dev.ensureInitialized()
|
||||||
|
|
||||||
|
// Make sure the read times out. This only works for files that support
|
||||||
|
// polling (see above).
|
||||||
|
// When no timeout is desired, passing the zero time removes the deadline.
|
||||||
|
var deadline time.Time
|
||||||
|
if timeout > 0 {
|
||||||
|
deadline = time.Now().Add(timeout)
|
||||||
|
}
|
||||||
|
if err := dev.file.SetReadDeadline(deadline); err != nil {
|
||||||
|
return 0, fmt.Errorf("set deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dev.file.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return n, fmt.Errorf("read up to %d bytes: %w", len(buf), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureInitialized is used as a guard to prevent methods to be called on an
|
||||||
|
// uninitialized instance.
|
||||||
|
func (dev *Device) ensureInitialized() {
|
||||||
|
if dev.file == nil {
|
||||||
|
panic("device is not initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
132
overlay/tuntap/device_test.go
Normal file
132
overlay/tuntap/device_test.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package tuntap_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gopacket/gopacket/afpacket"
|
||||||
|
"github.com/hetznercloud/virtio-go/internal/testsupport"
|
||||||
|
"github.com/hetznercloud/virtio-go/tuntap"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDevice(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
t.Run("with static name", func(t *testing.T) {
|
||||||
|
const name = "test42"
|
||||||
|
dev, err := tuntap.NewDevice(
|
||||||
|
tuntap.WithDeviceType(tuntap.DeviceTypeTAP),
|
||||||
|
tuntap.WithName(name),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, dev.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, name, dev.Name())
|
||||||
|
|
||||||
|
iface, err := net.InterfaceByIndex(int(dev.Ifindex()))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, name, iface.Name)
|
||||||
|
assert.Equal(t, dev.MAC(), iface.HardwareAddr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with auto selected name", func(t *testing.T) {
|
||||||
|
dev, err := tuntap.NewDevice(
|
||||||
|
tuntap.WithDeviceType(tuntap.DeviceTypeTAP),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, dev.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Contains(t, dev.Name(), "tap")
|
||||||
|
|
||||||
|
iface, err := net.InterfaceByIndex(int(dev.Ifindex()))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, dev.Name(), iface.Name)
|
||||||
|
assert.Equal(t, dev.MAC(), iface.HardwareAddr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_WritePacket(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
dev, tPacket := setupTestDevice(t)
|
||||||
|
|
||||||
|
// Write a test packet to the TAP device.
|
||||||
|
_, pkt := testsupport.TestPacket(t, dev.MAC(), 64)
|
||||||
|
assert.NoError(t, dev.WritePacket(pkt))
|
||||||
|
|
||||||
|
// Check if the packet arrived in the RAW socket.
|
||||||
|
data, _, err := tPacket.ReadPacketData()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, pkt, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_ReadPacket(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
dev, tPacket := setupTestDevice(t)
|
||||||
|
|
||||||
|
// Write a test packet to the RAW socket.
|
||||||
|
_, pkt := testsupport.TestPacket(t, dev.MAC(), 64)
|
||||||
|
assert.NoError(t, tPacket.WritePacketData(pkt))
|
||||||
|
|
||||||
|
// Check if the packet arrived at the TAP device.
|
||||||
|
receiveBuf := make([]byte, 1024)
|
||||||
|
n, err := dev.ReadPacket(receiveBuf, time.Second)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, len(pkt), n)
|
||||||
|
assert.Equal(t, pkt, receiveBuf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_ReadPacket_Timeout(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
dev, _ := setupTestDevice(t)
|
||||||
|
|
||||||
|
// Try to receive a packet on the TAP device when none was sent.
|
||||||
|
// This should time out.
|
||||||
|
receiveBuf := make([]byte, 1024)
|
||||||
|
_, err := dev.ReadPacket(receiveBuf, 500*time.Millisecond)
|
||||||
|
assert.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestDevice(t *testing.T) (*tuntap.Device, *afpacket.TPacket) {
|
||||||
|
t.Helper()
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
// Make sure the Linux kernel does not send router solicitations that may
|
||||||
|
// interfere with these tests.
|
||||||
|
testsupport.SetSysctl(t, "net.ipv6.conf.all.disable_ipv6", "1")
|
||||||
|
|
||||||
|
// Create a TAP device.
|
||||||
|
dev, err := tuntap.NewDevice(
|
||||||
|
tuntap.WithDeviceType(tuntap.DeviceTypeTAP),
|
||||||
|
// Helps to stop the Linux kernel from sending packets on this
|
||||||
|
// interface.
|
||||||
|
tuntap.WithInterfaceFlags(unix.IFF_NOARP),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, dev.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Open a RAW socket to capture packets arriving at the TAP device or
|
||||||
|
// write packets to it.
|
||||||
|
tPacket, err := afpacket.NewTPacket(
|
||||||
|
afpacket.SocketRaw,
|
||||||
|
afpacket.TPacketVersion3,
|
||||||
|
afpacket.OptInterface(dev.Name()),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(tPacket.Close)
|
||||||
|
|
||||||
|
return dev, tPacket
|
||||||
|
}
|
||||||
3
overlay/tuntap/doc.go
Normal file
3
overlay/tuntap/doc.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
// Package tuntap provides methods to create TUN/TAP devices and send and
|
||||||
|
// receive packets on them.
|
||||||
|
package tuntap
|
||||||
116
overlay/tuntap/options.go
Normal file
116
overlay/tuntap/options.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package tuntap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceType is the TUN/TAP device type.
|
||||||
|
type DeviceType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DeviceTypeTUN can be used to create TUN devices that operate on layer 3.
|
||||||
|
// Packets that are transported over TUN devices do not have an Ethernet
|
||||||
|
// header.
|
||||||
|
DeviceTypeTUN DeviceType = unix.IFF_TUN
|
||||||
|
// DeviceTypeTAP can be used to create TAP devices that operate on layer 2.
|
||||||
|
// Packets that are transported over TAP devices do have an Ethernet header.
|
||||||
|
DeviceTypeTAP DeviceType = unix.IFF_TAP
|
||||||
|
)
|
||||||
|
|
||||||
|
type optionValues struct {
|
||||||
|
name string
|
||||||
|
deviceType DeviceType
|
||||||
|
virtioNetHdr bool
|
||||||
|
offloads int
|
||||||
|
interfaceFlags uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *optionValues) apply(options []Option) {
|
||||||
|
for _, option := range options {
|
||||||
|
option(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *optionValues) validate() error {
|
||||||
|
if len(o.name) >= unix.IFNAMSIZ {
|
||||||
|
return errors.New("name must not be longer that 15 characters")
|
||||||
|
}
|
||||||
|
if o.deviceType != DeviceTypeTUN && o.deviceType != DeviceTypeTAP {
|
||||||
|
return errors.New("device type is required and must be either TUN or TAP")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *optionValues) ifreqFlags() uint16 {
|
||||||
|
flags := uint16(o.deviceType)
|
||||||
|
|
||||||
|
// Disable the packet information prefix.
|
||||||
|
flags |= unix.IFF_NO_PI
|
||||||
|
|
||||||
|
// Ensure the ioctl fails when a device with the same name already exists.
|
||||||
|
flags |= unix.IFF_TUN_EXCL
|
||||||
|
|
||||||
|
if o.virtioNetHdr {
|
||||||
|
// Also requires the TUNSETVNETHDRSZ ioctl at a later time.
|
||||||
|
flags |= unix.IFF_VNET_HDR
|
||||||
|
}
|
||||||
|
|
||||||
|
return flags
|
||||||
|
}
|
||||||
|
|
||||||
|
var optionDefaults = optionValues{
|
||||||
|
// Let the kernel auto-select a name.
|
||||||
|
name: "",
|
||||||
|
// Required.
|
||||||
|
deviceType: -1,
|
||||||
|
// Don't enable it by default to avoid surprises.
|
||||||
|
virtioNetHdr: false,
|
||||||
|
// Optional. No offload support advertised by default.
|
||||||
|
offloads: 0,
|
||||||
|
// Optional. IFF_UP will always be set.
|
||||||
|
interfaceFlags: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option can be passed to [NewDevice] to influence device creation.
|
||||||
|
type Option func(*optionValues)
|
||||||
|
|
||||||
|
// WithName returns an [Option] that sets the name of the to be created device.
|
||||||
|
// This is optional. When no name is specified, the kernel will auto-select a
|
||||||
|
// name using the scheme "tunX" or "tapX".
|
||||||
|
func WithName(name string) Option {
|
||||||
|
return func(o *optionValues) { o.name = name }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDeviceType returns an [Option] that sets the type of device that should
|
||||||
|
// be created.
|
||||||
|
// This is required.
|
||||||
|
func WithDeviceType(deviceType DeviceType) Option {
|
||||||
|
return func(o *optionValues) { o.deviceType = deviceType }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithVirtioNetHdr returns an [Option] that sets whether packets that are
|
||||||
|
// transported over the device are prepended with a [virtio.NetHdr].
|
||||||
|
// This is optional and disabled by default.
|
||||||
|
func WithVirtioNetHdr(enable bool) Option {
|
||||||
|
return func(o *optionValues) { o.virtioNetHdr = enable }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithOffloads returns an [Option] that sets the supported offloads that the
|
||||||
|
// device should advertise. This tells the kernel which offloads the owner of
|
||||||
|
// the device can deal with ([unix.TUN_F_CSUM] for example).
|
||||||
|
// This is optional. By default, no offloads are supported.
|
||||||
|
// When configured, then [WithVirtioNetHdr] should also be enabled.
|
||||||
|
func WithOffloads(offloads int) Option {
|
||||||
|
return func(o *optionValues) { o.offloads = offloads }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithInterfaceFlags returns an [Option] that sets the flags that should be
|
||||||
|
// used when taking the created interface up.
|
||||||
|
// This is optional. The [unix.IFF_UP] flag will always be set.
|
||||||
|
// The [unix.IFF_NOARP] flag may be useful in some scenarios to avoid packets
|
||||||
|
// from the Linux networking stack interfering with your application.
|
||||||
|
func WithInterfaceFlags(flags uint16) Option {
|
||||||
|
return func(o *optionValues) { o.interfaceFlags = flags }
|
||||||
|
}
|
||||||
79
overlay/tuntap/options_internal_test.go
Normal file
79
overlay/tuntap/options_internal_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package tuntap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOptionValues_Apply(t *testing.T) {
|
||||||
|
opts := optionDefaults
|
||||||
|
opts.apply([]Option{
|
||||||
|
WithName("name"),
|
||||||
|
WithDeviceType(DeviceTypeTAP),
|
||||||
|
WithVirtioNetHdr(true),
|
||||||
|
WithOffloads(unix.TUN_F_CSUM),
|
||||||
|
WithInterfaceFlags(unix.IFF_NOARP),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, optionValues{
|
||||||
|
name: "name",
|
||||||
|
deviceType: DeviceTypeTAP,
|
||||||
|
virtioNetHdr: true,
|
||||||
|
offloads: unix.TUN_F_CSUM,
|
||||||
|
interfaceFlags: unix.IFF_NOARP,
|
||||||
|
}, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionValues_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
values optionValues
|
||||||
|
assertErr assert.ErrorAssertionFunc
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "name too long",
|
||||||
|
values: optionValues{
|
||||||
|
name: "thisisaverylongname",
|
||||||
|
deviceType: DeviceTypeTAP,
|
||||||
|
},
|
||||||
|
assertErr: assert.Error,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "device type missing",
|
||||||
|
values: optionValues{},
|
||||||
|
assertErr: assert.Error,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid device type",
|
||||||
|
values: optionValues{
|
||||||
|
deviceType: 999,
|
||||||
|
},
|
||||||
|
assertErr: assert.Error,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid minimal",
|
||||||
|
values: optionValues{
|
||||||
|
deviceType: DeviceTypeTAP,
|
||||||
|
},
|
||||||
|
assertErr: assert.NoError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid full",
|
||||||
|
values: optionValues{
|
||||||
|
name: "name",
|
||||||
|
deviceType: DeviceTypeTAP,
|
||||||
|
virtioNetHdr: true,
|
||||||
|
offloads: unix.TUN_F_CSUM,
|
||||||
|
interfaceFlags: unix.IFF_NOARP,
|
||||||
|
},
|
||||||
|
assertErr: assert.NoError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.assertErr(t, tt.values.validate())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
412
overlay/vhostnet/device.go
Normal file
412
overlay/vhostnet/device.go
Normal file
@@ -0,0 +1,412 @@
|
|||||||
|
package vhostnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/hetznercloud/virtio-go/vhost"
|
||||||
|
"github.com/hetznercloud/virtio-go/virtio"
|
||||||
|
"github.com/hetznercloud/virtio-go/virtqueue"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrDeviceClosed is returned when the [Device] is closed while operations are
|
||||||
|
// still running.
|
||||||
|
var ErrDeviceClosed = errors.New("device was closed")
|
||||||
|
|
||||||
|
// The indexes for the receive and transmit queues.
|
||||||
|
const (
|
||||||
|
receiveQueueIndex = 0
|
||||||
|
transmitQueueIndex = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// Device represents a vhost networking device within the kernel-level virtio
|
||||||
|
// implementation and provides methods to interact with it.
|
||||||
|
type Device struct {
|
||||||
|
initialized bool
|
||||||
|
controlFD int
|
||||||
|
|
||||||
|
receiveQueue *virtqueue.SplitQueue
|
||||||
|
transmitQueue *virtqueue.SplitQueue
|
||||||
|
|
||||||
|
// transmitted contains channels for each possible descriptor chain head
|
||||||
|
// index. This is used for packet transmit notifications.
|
||||||
|
// When a packet was transmitted and the descriptor chain was used by the
|
||||||
|
// device, the corresponding channel receives the [virtqueue.UsedElement]
|
||||||
|
// instance provided by the device.
|
||||||
|
transmitted []chan virtqueue.UsedElement
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDevice initializes a new vhost networking device within the
|
||||||
|
// kernel-level virtio implementation, sets up the virtqueues and returns a
|
||||||
|
// [Device] instance that can be used to communicate with that vhost device.
|
||||||
|
//
|
||||||
|
// There are multiple options that can be passed to this constructor to
|
||||||
|
// influence device creation:
|
||||||
|
// - [WithQueueSize]
|
||||||
|
// - [WithBackendFD]
|
||||||
|
// - [WithBackendDevice]
|
||||||
|
//
|
||||||
|
// Remember to call [Device.Close] after use to free up resources.
|
||||||
|
func NewDevice(options ...Option) (_ *Device, err error) {
|
||||||
|
opts := optionDefaults
|
||||||
|
opts.apply(options)
|
||||||
|
if err = opts.validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid options: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dev := Device{
|
||||||
|
controlFD: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up a partially initialized device when something fails.
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = dev.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Retrieve a new control file descriptor. This will be used to configure
|
||||||
|
// the vhost networking device in the kernel.
|
||||||
|
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get control file descriptor: %w", err)
|
||||||
|
}
|
||||||
|
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
|
||||||
|
return nil, fmt.Errorf("own control file descriptor: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advertise the supported features. This isn't much for now.
|
||||||
|
// TODO: Add feature options and implement proper feature negotiation.
|
||||||
|
features := virtio.FeatureVersion1 // | virtio.FeatureNetMergeRXBuffers
|
||||||
|
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
|
||||||
|
return nil, fmt.Errorf("set features: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize and register the queues needed for the networking device.
|
||||||
|
if dev.receiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
|
||||||
|
return nil, fmt.Errorf("create receive queue: %w", err)
|
||||||
|
}
|
||||||
|
if dev.transmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
|
||||||
|
return nil, fmt.Errorf("create transmit queue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up memory mappings for all buffers used by the queues. This has to
|
||||||
|
// happen before a backend for the queues can be registered.
|
||||||
|
memoryLayout := vhost.NewMemoryLayoutForQueues(
|
||||||
|
[]*virtqueue.SplitQueue{dev.receiveQueue, dev.transmitQueue},
|
||||||
|
)
|
||||||
|
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
|
||||||
|
return nil, fmt.Errorf("setup memory layout: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the queue backends. This activates the queues within the kernel.
|
||||||
|
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
|
||||||
|
return nil, fmt.Errorf("set receive queue backend: %w", err)
|
||||||
|
}
|
||||||
|
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
|
||||||
|
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fully populate the receive queue with available buffers which the device
|
||||||
|
// can write new packets into.
|
||||||
|
if err = dev.refillReceiveQueue(); err != nil {
|
||||||
|
return nil, fmt.Errorf("refill receive queue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize channels for transmit notifications.
|
||||||
|
dev.transmitted = make([]chan virtqueue.UsedElement, dev.transmitQueue.Size())
|
||||||
|
for i := range len(dev.transmitted) {
|
||||||
|
// It is important to use a single-element buffered channel here.
|
||||||
|
// When the channel was unbuffered and the monitorTransmitQueue
|
||||||
|
// goroutine would write into it, the writing would block which could
|
||||||
|
// lead to deadlocks in case transmit notifications do not arrive in
|
||||||
|
// order.
|
||||||
|
// When the goroutine would use fire-and-forget to write into that
|
||||||
|
// channel, there may be a chance that the TransmitPacket does not
|
||||||
|
// receive the transmit notification due to this being a race condition.
|
||||||
|
// Buffering a single transmit notification resolves this without race
|
||||||
|
// conditions or possible deadlocks.
|
||||||
|
dev.transmitted[i] = make(chan virtqueue.UsedElement, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Monitor transmit queue in background.
|
||||||
|
go dev.monitorTransmitQueue()
|
||||||
|
|
||||||
|
dev.initialized = true
|
||||||
|
|
||||||
|
// Make sure to clean up even when the device gets garbage collected without
|
||||||
|
// Close being called first.
|
||||||
|
devPtr := &dev
|
||||||
|
runtime.SetFinalizer(devPtr, (*Device).Close)
|
||||||
|
|
||||||
|
return devPtr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// monitorTransmitQueue waits for the device to advertise used descriptor chains
|
||||||
|
// in the transmit queue and produces a transmit notification via the
|
||||||
|
// corresponding channel.
|
||||||
|
func (dev *Device) monitorTransmitQueue() {
|
||||||
|
usedChan := dev.transmitQueue.UsedDescriptorChains()
|
||||||
|
for {
|
||||||
|
used, ok := <-usedChan
|
||||||
|
if !ok {
|
||||||
|
// The queue was closed.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if int(used.DescriptorIndex) > len(dev.transmitted) {
|
||||||
|
panic(fmt.Sprintf("device provided a used descriptor index (%d) that is out of range",
|
||||||
|
used.DescriptorIndex))
|
||||||
|
}
|
||||||
|
|
||||||
|
dev.transmitted[used.DescriptorIndex] <- used
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransmitPacket writes the given packet into the transmit queue of this
|
||||||
|
// device. The packet will be prepended with the [virtio.NetHdr].
|
||||||
|
//
|
||||||
|
// When the queue is full, this will block until the queue has enough room to
|
||||||
|
// transmit the packet. This method will not return before the packet was
|
||||||
|
// transmitted and the device notifies that it has used the packet buffer.
|
||||||
|
func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
|
||||||
|
// Prepend the packet with its virtio-net header.
|
||||||
|
vnethdrBuf := make([]byte, virtio.NetHdrSize) //todo WHY
|
||||||
|
if err := vnethdr.Encode(vnethdrBuf); err != nil {
|
||||||
|
return fmt.Errorf("encode vnethdr: %w", err)
|
||||||
|
}
|
||||||
|
outBuffers := [][]byte{vnethdrBuf, packet}
|
||||||
|
|
||||||
|
chainIndex, err := dev.transmitQueue.OfferDescriptorChain(outBuffers, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("offer descriptor chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the packet to have been transmitted.
|
||||||
|
<-dev.transmitted[chainIndex]
|
||||||
|
|
||||||
|
if err = dev.transmitQueue.FreeDescriptorChain(chainIndex); err != nil {
|
||||||
|
return fmt.Errorf("free descriptor chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceivePacket reads the next available packet from the receive queue of this
|
||||||
|
// device and returns its [virtio.NetHdr] and packet data separately.
|
||||||
|
//
|
||||||
|
// When no packet is available, this will block until there is one.
|
||||||
|
//
|
||||||
|
// When this method returns an error, the receive queue will likely be in a
|
||||||
|
// broken state which this implementation cannot recover from. The caller should
|
||||||
|
// close the device and not attempt any additional receives.
|
||||||
|
func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
|
||||||
|
var (
|
||||||
|
chainHeads []uint16
|
||||||
|
|
||||||
|
vnethdr virtio.NetHdr
|
||||||
|
buffers [][]byte
|
||||||
|
|
||||||
|
// Each packet starts with a virtio-net header which we have to subtract
|
||||||
|
// from the total length.
|
||||||
|
packetLength = -virtio.NetHdrSize
|
||||||
|
)
|
||||||
|
|
||||||
|
// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
|
||||||
|
// made of multiple descriptor chains which are to be merged.
|
||||||
|
for remainingChains := 1; remainingChains > 0; remainingChains-- {
|
||||||
|
// Get the next descriptor chain.
|
||||||
|
usedElement, ok := <-dev.receiveQueue.UsedDescriptorChains()
|
||||||
|
if !ok {
|
||||||
|
return virtio.NetHdr{}, nil, ErrDeviceClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track this chain to be freed later.
|
||||||
|
head := uint16(usedElement.DescriptorIndex)
|
||||||
|
chainHeads = append(chainHeads, head)
|
||||||
|
|
||||||
|
outBuffers, inBuffers, err := dev.receiveQueue.GetDescriptorChain(head)
|
||||||
|
if err != nil {
|
||||||
|
// When this fails we may miss to free some descriptor chains. We
|
||||||
|
// could try to mitigate this by deferring the freeing somehow, but
|
||||||
|
// it's not worth the hassle. When this method fails, the queue will
|
||||||
|
// be in a broken state anyway.
|
||||||
|
return virtio.NetHdr{}, nil, fmt.Errorf("get descriptor chain: %w", err)
|
||||||
|
}
|
||||||
|
if len(outBuffers) > 0 {
|
||||||
|
// How did this happen!?
|
||||||
|
panic("receive queue contains device-readable buffers")
|
||||||
|
}
|
||||||
|
if len(inBuffers) == 0 {
|
||||||
|
// Empty descriptor chains should not be possible.
|
||||||
|
panic("descriptor chain contains no buffers")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The device tells us how many bytes of the descriptor chain it has
|
||||||
|
// actually written to. The specification forces the device to fully
|
||||||
|
// fill up all but the last descriptor chain when multiple descriptor
|
||||||
|
// chains are being merged, but being more compatible here doesn't hurt.
|
||||||
|
inBuffers = truncateBuffers(inBuffers, int(usedElement.Length))
|
||||||
|
packetLength += int(usedElement.Length)
|
||||||
|
|
||||||
|
// Is this the first descriptor chain we process?
|
||||||
|
if len(buffers) == 0 {
|
||||||
|
// The specification requires that the first descriptor chain starts
|
||||||
|
// with a virtio-net header. It is not clear, whether it is also
|
||||||
|
// required to be fully contained in the first buffer of that
|
||||||
|
// descriptor chain, but it is reasonable to assume that this is
|
||||||
|
// always the case.
|
||||||
|
// The decode method already does the buffer length check.
|
||||||
|
if err = vnethdr.Decode(inBuffers[0]); err != nil {
|
||||||
|
// The device misbehaved. There is no way we can gracefully
|
||||||
|
// recover from this, because we don't know how many of the
|
||||||
|
// following descriptor chains belong to this packet.
|
||||||
|
return virtio.NetHdr{}, nil, fmt.Errorf("decode vnethdr: %w", err)
|
||||||
|
}
|
||||||
|
inBuffers[0] = inBuffers[0][virtio.NetHdrSize:]
|
||||||
|
|
||||||
|
// The virtio-net header tells us how many descriptor chains this
|
||||||
|
// packet is long.
|
||||||
|
remainingChains = int(vnethdr.NumBuffers)
|
||||||
|
}
|
||||||
|
|
||||||
|
buffers = append(buffers, inBuffers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy all the buffers together to produce the complete packet slice.
|
||||||
|
packet := make([]byte, packetLength)
|
||||||
|
copied := 0
|
||||||
|
for _, buffer := range buffers {
|
||||||
|
copied += copy(packet[copied:], buffer)
|
||||||
|
}
|
||||||
|
if copied != packetLength {
|
||||||
|
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we have copied all buffers, we can free the used descriptor
|
||||||
|
// chains again.
|
||||||
|
// TODO: Recycling the descriptor chains would be more efficient than
|
||||||
|
// freeing them just to offer them again right after.
|
||||||
|
for _, head := range chainHeads {
|
||||||
|
if err := dev.receiveQueue.FreeDescriptorChain(head); err != nil {
|
||||||
|
return virtio.NetHdr{}, nil, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's advised to always keep the receive queue fully populated with
|
||||||
|
// available buffers which the device can write new packets into.
|
||||||
|
if err := dev.refillReceiveQueue(); err != nil {
|
||||||
|
return virtio.NetHdr{}, nil, fmt.Errorf("refill receive queue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return vnethdr, packet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Make above methods cancelable by taking a context.Context argument?
|
||||||
|
// TODO: Implement zero-copy variants to transmit and receive packets?
|
||||||
|
|
||||||
|
// refillReceiveQueue offers as many new device-writable buffers to the device
|
||||||
|
// as the queue can fit. The device will then use these to write received
|
||||||
|
// packets.
|
||||||
|
func (dev *Device) refillReceiveQueue() error {
|
||||||
|
for {
|
||||||
|
_, err := dev.receiveQueue.OfferDescriptorChain(nil, 1, false)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
||||||
|
// Queue is full, job is done.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("offer descriptor chain: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up the vhost networking device within the kernel and releases
|
||||||
|
// all resources used for it.
|
||||||
|
// The implementation will try to release as many resources as possible and
|
||||||
|
// collect potential errors before returning them.
|
||||||
|
func (dev *Device) Close() error {
|
||||||
|
dev.initialized = false
|
||||||
|
|
||||||
|
// Closing the control file descriptor will unregister all queues from the
|
||||||
|
// kernel.
|
||||||
|
if dev.controlFD >= 0 {
|
||||||
|
if err := unix.Close(dev.controlFD); err != nil {
|
||||||
|
// Return an error and do not continue, because the memory used for
|
||||||
|
// the queues should not be released before they were unregistered
|
||||||
|
// from the kernel.
|
||||||
|
return fmt.Errorf("close control file descriptor: %w", err)
|
||||||
|
}
|
||||||
|
dev.controlFD = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if dev.receiveQueue != nil {
|
||||||
|
if err := dev.receiveQueue.Close(); err == nil {
|
||||||
|
dev.receiveQueue = nil
|
||||||
|
} else {
|
||||||
|
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.transmitQueue != nil {
|
||||||
|
if err := dev.transmitQueue.Close(); err == nil {
|
||||||
|
dev.transmitQueue = nil
|
||||||
|
} else {
|
||||||
|
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) == 0 {
|
||||||
|
// Everything was cleaned up. No need to run the finalizer anymore.
|
||||||
|
runtime.SetFinalizer(dev, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureInitialized is used as a guard to prevent methods to be called on an
|
||||||
|
// uninitialized instance.
|
||||||
|
func (dev *Device) ensureInitialized() {
|
||||||
|
if !dev.initialized {
|
||||||
|
panic("device is not initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createQueue creates a new virtqueue and registers it with the vhost device
|
||||||
|
// using the given index.
|
||||||
|
func createQueue(controlFD int, queueIndex int, queueSize int) (*virtqueue.SplitQueue, error) {
|
||||||
|
var (
|
||||||
|
queue *virtqueue.SplitQueue
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if queue, err = virtqueue.NewSplitQueue(queueSize); err != nil {
|
||||||
|
return nil, fmt.Errorf("create virtqueue: %w", err)
|
||||||
|
}
|
||||||
|
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
||||||
|
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
|
||||||
|
}
|
||||||
|
return queue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateBuffers returns a new list of buffers whose combined length matches
|
||||||
|
// exactly the specified length. When the specified length exceeds the length of
|
||||||
|
// the buffers, this is an error. When it is smaller, the buffer list will be
|
||||||
|
// truncated accordingly.
|
||||||
|
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
|
||||||
|
for _, buffer := range buffers {
|
||||||
|
if length < len(buffer) {
|
||||||
|
out = append(out, buffer[:length])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out = append(out, buffer)
|
||||||
|
length -= len(buffer)
|
||||||
|
}
|
||||||
|
if length > 0 {
|
||||||
|
panic("length exceeds the combined length of all buffers")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
86
overlay/vhostnet/device_internal_test.go
Normal file
86
overlay/vhostnet/device_internal_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package vhostnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTruncateBuffers(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
buffers [][]byte
|
||||||
|
length int
|
||||||
|
expected [][]byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no buffers",
|
||||||
|
buffers: nil,
|
||||||
|
length: 0,
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single buffer correct length",
|
||||||
|
buffers: [][]byte{
|
||||||
|
make([]byte, 100),
|
||||||
|
},
|
||||||
|
length: 100,
|
||||||
|
expected: [][]byte{
|
||||||
|
make([]byte, 100),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single buffer truncated",
|
||||||
|
buffers: [][]byte{
|
||||||
|
make([]byte, 100),
|
||||||
|
},
|
||||||
|
length: 90,
|
||||||
|
expected: [][]byte{
|
||||||
|
make([]byte, 90),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple buffers correct length",
|
||||||
|
buffers: [][]byte{
|
||||||
|
make([]byte, 200),
|
||||||
|
make([]byte, 100),
|
||||||
|
},
|
||||||
|
length: 300,
|
||||||
|
expected: [][]byte{
|
||||||
|
make([]byte, 200),
|
||||||
|
make([]byte, 100),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple buffers truncated",
|
||||||
|
buffers: [][]byte{
|
||||||
|
make([]byte, 200),
|
||||||
|
make([]byte, 100),
|
||||||
|
},
|
||||||
|
length: 250,
|
||||||
|
expected: [][]byte{
|
||||||
|
make([]byte, 200),
|
||||||
|
make([]byte, 50),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple buffers truncated buffer list",
|
||||||
|
buffers: [][]byte{
|
||||||
|
make([]byte, 200),
|
||||||
|
make([]byte, 200),
|
||||||
|
make([]byte, 200),
|
||||||
|
},
|
||||||
|
length: 350,
|
||||||
|
expected: [][]byte{
|
||||||
|
make([]byte, 200),
|
||||||
|
make([]byte, 150),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
actual := truncateBuffers(tt.buffers, tt.length)
|
||||||
|
assert.Equal(t, tt.expected, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
224
overlay/vhostnet/device_test.go
Normal file
224
overlay/vhostnet/device_test.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
package vhostnet_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gopacket/gopacket/afpacket"
|
||||||
|
"github.com/hetznercloud/virtio-go/internal/testsupport"
|
||||||
|
"github.com/hetznercloud/virtio-go/tuntap"
|
||||||
|
"github.com/hetznercloud/virtio-go/vhostnet"
|
||||||
|
"github.com/hetznercloud/virtio-go/virtio"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Here is the general idea of how the following tests work to verify the
|
||||||
|
// correct communication with the vhost-net device within the kernel:
|
||||||
|
//
|
||||||
|
// +-----------------------------------+
|
||||||
|
// | go test running in user space |
|
||||||
|
// +-----------------------------------+
|
||||||
|
// ^ ^
|
||||||
|
// | |
|
||||||
|
// capture / write transmit / receive
|
||||||
|
// using AF_PACKET using this package
|
||||||
|
// | |
|
||||||
|
// v v
|
||||||
|
// +----------------+ +-----------+
|
||||||
|
// | tun (TAP mode) |<---->| vhost-net |
|
||||||
|
// +----------------+ +-----------+
|
||||||
|
//
|
||||||
|
|
||||||
|
func TestDevice_TransmitPacket(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
fx := NewTestFixture(t)
|
||||||
|
|
||||||
|
for _, length := range []int{64, 1514, 9014, 64100} {
|
||||||
|
t.Run(fmt.Sprintf("%d byte packet", length), func(t *testing.T) {
|
||||||
|
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), length)
|
||||||
|
|
||||||
|
// Transmit the packet over the vhost-net device.
|
||||||
|
require.NoError(t, fx.NetDevice.TransmitPacket(vnethdr, pkt))
|
||||||
|
|
||||||
|
// Check if the packet arrived at the TAP device. The virtio-net
|
||||||
|
// header should have been stripped by the TAP device.
|
||||||
|
data, _, err := fx.TPacket.ReadPacketData()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, pkt, data)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_ReceivePacket(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
fx := NewTestFixture(t)
|
||||||
|
|
||||||
|
for _, length := range []int{64, 1514, 9014, 64100} {
|
||||||
|
t.Run(fmt.Sprintf("%d byte packet", length), func(t *testing.T) {
|
||||||
|
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), length)
|
||||||
|
prependedPkt := testsupport.PrependPacket(t, vnethdr, pkt)
|
||||||
|
|
||||||
|
// Write the prepended packet to the TAP device.
|
||||||
|
require.NoError(t, fx.TPacket.WritePacketData(prependedPkt))
|
||||||
|
|
||||||
|
// Try to receive the packet on the vhost-net device.
|
||||||
|
vnethdr, data, err := fx.NetDevice.ReceivePacket()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, pkt, data)
|
||||||
|
|
||||||
|
// Large packets should have been received as multiple buffers.
|
||||||
|
assert.Equal(t, (len(prependedPkt)/os.Getpagesize())+1, int(vnethdr.NumBuffers))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_TransmitManyPackets(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
fx := NewTestFixture(t)
|
||||||
|
|
||||||
|
// Test with a packet which does not fit into a single memory page.
|
||||||
|
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), 9014)
|
||||||
|
|
||||||
|
const count = 1024
|
||||||
|
var received int
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Go(func() {
|
||||||
|
for range count {
|
||||||
|
err := fx.NetDevice.TransmitPacket(vnethdr, pkt)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
wg.Go(func() {
|
||||||
|
for range count {
|
||||||
|
data, _, err := fx.TPacket.ReadPacketData()
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, pkt, data)
|
||||||
|
received++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
assert.Equal(t, count, received)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevice_ReceiveManyPackets(t *testing.T) {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
fx := NewTestFixture(t)
|
||||||
|
|
||||||
|
// Test with a packet which does not fit into a single memory page.
|
||||||
|
vnethdr, pkt := testsupport.TestPacket(t, fx.TAPDevice.MAC(), 9014)
|
||||||
|
prependedPkt := testsupport.PrependPacket(t, vnethdr, pkt)
|
||||||
|
|
||||||
|
const count = 1024
|
||||||
|
var received int
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Go(func() {
|
||||||
|
for range count {
|
||||||
|
err := fx.TPacket.WritePacketData(prependedPkt)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
wg.Go(func() {
|
||||||
|
for range count {
|
||||||
|
_, data, err := fx.NetDevice.ReceivePacket()
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, pkt, data)
|
||||||
|
received++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
assert.Equal(t, count, received)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestFixture struct {
|
||||||
|
TAPDevice *tuntap.Device
|
||||||
|
NetDevice *vhostnet.Device
|
||||||
|
TPacket *afpacket.TPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestFixture(t *testing.T) *TestFixture {
|
||||||
|
testsupport.VirtrunOnly(t)
|
||||||
|
|
||||||
|
// In case something doesn't work, some more debug logging from the kernel
|
||||||
|
// modules may be very helpful.
|
||||||
|
testsupport.EnableDynamicDebug(t, "module tun")
|
||||||
|
testsupport.EnableDynamicDebug(t, "module vhost")
|
||||||
|
testsupport.EnableDynamicDebug(t, "module vhost_net")
|
||||||
|
|
||||||
|
// Make sure the Linux kernel does not send router solicitations that may
|
||||||
|
// interfere with these tests.
|
||||||
|
testsupport.SetSysctl(t, "net.ipv6.conf.all.disable_ipv6", "1")
|
||||||
|
|
||||||
|
var (
|
||||||
|
fx TestFixture
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create a TAP device.
|
||||||
|
fx.TAPDevice, err = tuntap.NewDevice(
|
||||||
|
tuntap.WithDeviceType(tuntap.DeviceTypeTAP),
|
||||||
|
// Helps to stop the Linux kernel from sending packets on this
|
||||||
|
// interface.
|
||||||
|
tuntap.WithInterfaceFlags(unix.IFF_NOARP),
|
||||||
|
// Packets going over this device are prepended with a virtio-net
|
||||||
|
// header. When this is not set, then packets written to the TAP device
|
||||||
|
// will be passed to the Linux network stack without their virtio-net
|
||||||
|
// header stripped.
|
||||||
|
tuntap.WithVirtioNetHdr(true),
|
||||||
|
// When writing packets into the TAP device using the RAW socket, we
|
||||||
|
// don't want the offloads to be applied by the kernel. Advertising
|
||||||
|
// offload support makes the kernel pass the offload request along to
|
||||||
|
// our vhost-net device.
|
||||||
|
tuntap.WithOffloads(unix.TUN_F_CSUM|unix.TUN_F_USO4|unix.TUN_F_USO6),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, fx.TAPDevice.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a vhost-net device that uses the TAP device as the backend.
|
||||||
|
fx.NetDevice, err = vhostnet.NewDevice(
|
||||||
|
vhostnet.WithQueueSize(32),
|
||||||
|
vhostnet.WithBackendDevice(fx.TAPDevice),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, fx.NetDevice.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Open a RAW socket to capture packets arriving at the TAP device or
|
||||||
|
// write packets into it.
|
||||||
|
fx.TPacket, err = afpacket.NewTPacket(
|
||||||
|
afpacket.SocketRaw,
|
||||||
|
afpacket.TPacketVersion3,
|
||||||
|
afpacket.OptInterface(fx.TAPDevice.Name()),
|
||||||
|
|
||||||
|
// Tell the kernel that packets written to this socket are prepended
|
||||||
|
// with a virto-net header. This is used to communicate the use of GSO
|
||||||
|
// for large packets.
|
||||||
|
afpacket.OptVNetHdrSize(virtio.NetHdrSize),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(fx.TPacket.Close)
|
||||||
|
|
||||||
|
return &fx
|
||||||
|
}
|
||||||
3
overlay/vhostnet/doc.go
Normal file
3
overlay/vhostnet/doc.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
// Package vhostnet implements methods to initialize vhost networking devices
|
||||||
|
// within the kernel-level virtio implementation and communicate with them.
|
||||||
|
package vhostnet
|
||||||
31
overlay/vhostnet/ioctl.go
Normal file
31
overlay/vhostnet/ioctl.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package vhostnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/hetznercloud/virtio-go/vhost"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
|
||||||
|
// or TAP device.
|
||||||
|
//
|
||||||
|
// Request payload: [vhost.QueueFile]
|
||||||
|
// Kernel name: VHOST_NET_SET_BACKEND
|
||||||
|
vhostNetIoctlSetBackend = 0x4008af30
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetQueueBackend attaches a virtqueue of the vhost networking device
|
||||||
|
// described by controlFD to the given backend file descriptor.
|
||||||
|
// The backend file descriptor can either be a RAW socket or a TAP device. When
|
||||||
|
// it is -1, the queue will be detached.
|
||||||
|
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
|
||||||
|
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
|
||||||
|
QueueIndex: queueIndex,
|
||||||
|
FD: int32(backendFD),
|
||||||
|
})); err != nil {
|
||||||
|
return fmt.Errorf("set queue backend file descriptor: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
70
overlay/vhostnet/options.go
Normal file
70
overlay/vhostnet/options.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package vhostnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/hetznercloud/virtio-go/tuntap"
|
||||||
|
"github.com/hetznercloud/virtio-go/virtqueue"
|
||||||
|
)
|
||||||
|
|
||||||
|
type optionValues struct {
|
||||||
|
queueSize int
|
||||||
|
backendFD int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *optionValues) apply(options []Option) {
|
||||||
|
for _, option := range options {
|
||||||
|
option(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *optionValues) validate() error {
|
||||||
|
if o.queueSize == -1 {
|
||||||
|
return errors.New("queue size is required")
|
||||||
|
}
|
||||||
|
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if o.backendFD == -1 {
|
||||||
|
return errors.New("backend file descriptor is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var optionDefaults = optionValues{
|
||||||
|
// Required.
|
||||||
|
queueSize: -1,
|
||||||
|
// Required.
|
||||||
|
backendFD: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option can be passed to [NewDevice] to influence device creation.
|
||||||
|
type Option func(*optionValues)
|
||||||
|
|
||||||
|
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
|
||||||
|
// that are to be created for the device. It specifies the number of
|
||||||
|
// entries/buffers each queue can hold. This also affects the memory
|
||||||
|
// consumption.
|
||||||
|
// This is required and must be an integer from 1 to 32768 that is also a power
|
||||||
|
// of 2.
|
||||||
|
func WithQueueSize(queueSize int) Option {
|
||||||
|
return func(o *optionValues) { o.queueSize = queueSize }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBackendFD returns an [Option] that sets the file descriptor of the
|
||||||
|
// backend that will be used for the queues of the device. The device will write
|
||||||
|
// and read packets to/from that backend. The file descriptor can either be of a
|
||||||
|
// RAW socket or TUN/TAP device.
|
||||||
|
// Either this or [WithBackendDevice] is required.
|
||||||
|
func WithBackendFD(backendFD int) Option {
|
||||||
|
return func(o *optionValues) { o.backendFD = backendFD }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBackendDevice returns an [Option] that sets the given TAP device as the
|
||||||
|
// backend that will be used for the queues of the device. The device will
|
||||||
|
// write and read packets to/from that backend. The TAP device should have been
|
||||||
|
// created with the [tuntap.WithVirtioNetHdr] option enabled.
|
||||||
|
// Either this or [WithBackendFD] is required.
|
||||||
|
func WithBackendDevice(dev *tuntap.Device) Option {
|
||||||
|
return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
|
||||||
|
}
|
||||||
66
overlay/vhostnet/options_internal_test.go
Normal file
66
overlay/vhostnet/options_internal_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package vhostnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOptionValues_Apply(t *testing.T) {
|
||||||
|
opts := optionDefaults
|
||||||
|
opts.apply([]Option{
|
||||||
|
WithQueueSize(256),
|
||||||
|
WithBackendFD(99),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, optionValues{
|
||||||
|
queueSize: 256,
|
||||||
|
backendFD: 99,
|
||||||
|
}, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionValues_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
values optionValues
|
||||||
|
assertErr assert.ErrorAssertionFunc
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "queue size missing",
|
||||||
|
values: optionValues{
|
||||||
|
queueSize: -1,
|
||||||
|
backendFD: 99,
|
||||||
|
},
|
||||||
|
assertErr: assert.Error,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid queue size",
|
||||||
|
values: optionValues{
|
||||||
|
queueSize: 24,
|
||||||
|
backendFD: 99,
|
||||||
|
},
|
||||||
|
assertErr: assert.Error,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "backend fd missing",
|
||||||
|
values: optionValues{
|
||||||
|
queueSize: 256,
|
||||||
|
backendFD: -1,
|
||||||
|
},
|
||||||
|
assertErr: assert.Error,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid",
|
||||||
|
values: optionValues{
|
||||||
|
queueSize: 256,
|
||||||
|
backendFD: 99,
|
||||||
|
},
|
||||||
|
assertErr: assert.NoError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.assertErr(t, tt.values.validate())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user