JackDoan
2025-11-07 14:26:35 -06:00
parent 3ec527e42c
commit ac5382928e
50 changed files with 4013 additions and 123 deletions


@@ -0,0 +1,23 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

overlay/vhostnet/device.go Normal file

@@ -0,0 +1,372 @@
package vhostnet
import (
"context"
"errors"
"fmt"
"os"
"runtime"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
// ErrDeviceClosed is returned when the [Device] is closed while operations are
// still running.
var ErrDeviceClosed = errors.New("device was closed")
// The indexes for the receive and transmit queues.
const (
receiveQueueIndex = 0
transmitQueueIndex = 1
)
// Device represents a vhost networking device within the kernel-level virtio
// implementation and provides methods to interact with it.
type Device struct {
initialized bool
controlFD int
fullTable bool
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
}
// NewDevice initializes a new vhost networking device within the
// kernel-level virtio implementation, sets up the virtqueues and returns a
// [Device] instance that can be used to communicate with that vhost device.
//
// There are multiple options that can be passed to this constructor to
// influence device creation:
// - [WithQueueSize]
// - [WithBackendFD]
//
// Remember to call [Device.Close] after use to free up resources.
func NewDevice(options ...Option) (*Device, error) {
var err error
opts := optionDefaults
opts.apply(options)
if err = opts.validate(); err != nil {
return nil, fmt.Errorf("invalid options: %w", err)
}
dev := Device{
controlFD: -1,
}
// Clean up a partially initialized device when something fails.
defer func() {
if err != nil {
_ = dev.Close()
}
}()
// Retrieve a new control file descriptor. This will be used to configure
// the vhost networking device in the kernel.
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
if err != nil {
return nil, fmt.Errorf("get control file descriptor: %w", err)
}
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
return nil, fmt.Errorf("own control file descriptor: %w", err)
}
// Advertise the supported features. This isn't much for now.
// TODO: Add feature options and implement proper feature negotiation.
// The result of GetFeatures is not inspected yet; the call mainly verifies
// that the kernel responds before we set our own feature bits.
if _, err = vhost.GetFeatures(dev.controlFD); err != nil {
return nil, fmt.Errorf("get features: %w", err)
}
features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
return nil, fmt.Errorf("set features: %w", err)
}
itemSize := os.Getpagesize() * 4 // TODO: make this configurable.
// Initialize and register the queues needed for the networking device.
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
// Set up memory mappings for all buffers used by the queues. This has to
// happen before a backend for the queues can be registered.
memoryLayout := vhost.NewMemoryLayoutForQueues(
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
)
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
return nil, fmt.Errorf("setup memory layout: %w", err)
}
// Set the queue backends. This activates the queues within the kernel.
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set receive queue backend: %w", err)
}
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set transmit queue backend: %w", err)
}
// Fully populate the receive queue with available buffers which the device
// can write new packets into.
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without
// Close being called first.
devPtr := &dev
runtime.SetFinalizer(devPtr, (*Device).Close)
return devPtr, nil
}
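// A minimal, illustrative usage sketch (not called anywhere in this package):
// tapFD is assumed to be a TAP device opened with the virtio-net header
// enabled, and the queue size of 256 is an arbitrary example value.
func exampleNewDevice(tapFD int) (*Device, error) {
dev, err := NewDevice(
WithQueueSize(256),
WithBackendFD(tapFD),
)
if err != nil {
return nil, fmt.Errorf("create vhost-net device: %w", err)
}
return dev, nil
}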
// refillReceiveQueue offers as many new device-writable buffers to the device
// as the queue can fit. The device will then use these to write received
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
return nil
}
return fmt.Errorf("offer descriptor chain: %w", err)
}
}
}
// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (dev *Device) Close() error {
dev.initialized = false
// Closing the control file descriptor will unregister all queues from the
// kernel.
if dev.controlFD >= 0 {
if err := unix.Close(dev.controlFD); err != nil {
// Return an error and do not continue, because the memory used for
// the queues should not be released before they were unregistered
// from the kernel.
return fmt.Errorf("close control file descriptor: %w", err)
}
dev.controlFD = -1
}
var errs []error
if dev.ReceiveQueue != nil {
if err := dev.ReceiveQueue.Close(); err == nil {
dev.ReceiveQueue = nil
} else {
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
}
}
if dev.TransmitQueue != nil {
if err := dev.TransmitQueue.Close(); err == nil {
dev.TransmitQueue = nil
} else {
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
}
}
if len(errs) == 0 {
// Everything was cleaned up. No need to run the finalizer anymore.
runtime.SetFinalizer(dev, nil)
}
return errors.Join(errs...)
}
// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
var (
queue *virtqueue.SplitQueue
err error
)
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create virtqueue: %w", err)
}
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
}
return queue, nil
}
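// GetPacketForTx returns the index of a free transmit descriptor together with
// the buffer backing it. While the descriptor table still has unused entries, a
// fresh descriptor is created; once the table is full, the call blocks until
// the device hands back a used descriptor that can be reused.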
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
var err error
var idx uint16
if !dev.fullTable {
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
dev.fullTable = true
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
} else {
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err)
}
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
if err != nil {
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
}
return idx, buf, nil
}
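// TransmitPacket offers the descriptor chains referenced by pkt to the
// transmit queue and resets pkt. When kick is true, the device is notified
// immediately; callers that batch multiple packets can kick once at the end.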
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
if len(pkt.SegmentIDs) == 0 {
return nil
}
for idx := range pkt.SegmentIDs {
segmentID := pkt.SegmentIDs[idx]
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
}
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
if err != nil {
return fmt.Errorf("offer descriptor chains: %w", err)
}
pkt.Reset()
if kick {
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
}
return nil
}
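// TransmitPackets transmits a batch of packets and notifies the device once
// after all of them have been offered to the transmit queue.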
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
if len(pkts) == 0 {
return nil
}
for i := range pkts {
if err := dev.TransmitPacket(pkts[i], false); err != nil {
return err
}
}
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
return nil
}
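// exampleTransmit is an illustrative sketch of the transmit path, not part of
// the actual API. It assumes that packet.OutPacket exposes the SegmentIDs and
// Segments slices used above, and that the first virtio.NetHdrSize bytes of
// each segment hold the virtio-net header (a zeroed header is valid for plain
// packets without checksum offload or GSO).
func exampleTransmit(dev *Device, pkt *packet.OutPacket, payload []byte) error {
idx, buf, err := dev.GetPacketForTx()
if err != nil {
return err
}
n := copy(buf[virtio.NetHdrSize:], payload)
pkt.SegmentIDs = append(pkt.SegmentIDs, idx)
pkt.Segments = append(pkt.Segments, buf[:virtio.NetHdrSize+n])
return dev.TransmitPacket(pkt, true)
}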
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
// Read the first descriptor chain; its virtio-net header tells us how many
// chains this packet spans.
pkt.Reset()
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
if err != nil {
return 0, fmt.Errorf("get descriptor chain: %w", err)
}
if len(pkt.ChainRefs) == 0 {
return 1, nil
}
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = pkt.Header.Decode(pkt.ChainRefs[0]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return 0, fmt.Errorf("decode vnethdr: %w", err)
}
// We have the header; validate how many descriptor chains this packet spans.
if int(pkt.Header.NumBuffers) > len(chains) {
return 0, fmt.Errorf("packet spans %d buffers but only %d chains are available", pkt.Header.NumBuffers, len(chains))
}
if int(pkt.Header.NumBuffers) != 1 {
return 0, fmt.Errorf("multi-buffer packets are not supported yet: %d buffers", pkt.Header.NumBuffers)
}
if chains[0].Length > 16000 {
// TODO: handle oversized packets properly instead of rejecting them.
return 1, fmt.Errorf("packet length %d exceeds the supported maximum", chains[0].Length)
}
// Strip the virtio-net header; the rest of the buffer is the packet payload.
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
return 1, nil
//cursor := n - virtio.NetHdrSize
//
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
// return 1, nil
//}
//
//i := 1
//// we used chain 0 already
//for i = 1; i < len(chains); i++ {
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
// if err != nil {
// // When this fails we may miss to free some descriptor chains. We
// // could try to mitigate this by deferring the freeing somehow, but
// // it's not worth the hassle. When this method fails, the queue will
// // be in a broken state anyway.
// return i, fmt.Errorf("get descriptor chain: %w", err)
// }
// cursor += n
//}
////todo this has to be wrong
//pkt.Payload = pkt.Payload[:cursor]
//return i, nil
}
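// For reference, the header decoded above has the standard 12-byte layout from
// the virtio specification (12 bytes because VIRTIO_F_VERSION_1 and mergeable
// RX buffers are negotiated); this illustrative struct may differ from the
// actual virtio.NetHdr definition used by pkt.Header:
type illustrativeNetHdr struct {
Flags uint8 // e.g. VIRTIO_NET_HDR_F_NEEDS_CSUM
GSOType uint8 // e.g. VIRTIO_NET_HDR_GSO_NONE
HdrLen uint16 // length of ethernet + IP + transport headers
GSOSize uint16 // maximum segment size when GSO is in use
CsumStart uint16 // offset where checksumming starts
CsumOffset uint16 // checksum field offset relative to CsumStart
NumBuffers uint16 // number of merged RX buffers (VIRTIO_NET_F_MRG_RXBUF)
}

// ReceivePackets blocks until the device reports received descriptor chains,
// assembles them into the provided packets and returns how many packets were
// filled. It fails when more chains arrive than fit into out.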
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
//todo optimize?
chains, err := dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
if err != nil {
return 0, err
}
if len(chains) == 0 {
return 0, nil
}
numPackets := 0
chainsIdx := 0
for ; chainsIdx < len(chains); numPackets++ {
if numPackets >= len(out) {
return numPackets, fmt.Errorf("no room in output slice, dropping %d descriptor chains", len(chains)-numPackets)
}
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
if err != nil {
return 0, err
}
chainsIdx += numChains
}
return numPackets, nil
}
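// exampleReceiveLoop is an illustrative sketch of consuming received packets,
// not part of the actual API. It assumes a zero-valued packet.VirtIOPacket is
// usable and that the descriptor chains recorded in pkt.Chains are recycled
// elsewhere (for example by re-offering them to the receive queue), which is
// not shown here.
func exampleReceiveLoop(dev *Device, handle func([]byte)) error {
out := make([]*packet.VirtIOPacket, 64) // the batch size is arbitrary
for i := range out {
out[i] = &packet.VirtIOPacket{}
}
for {
n, err := dev.ReceivePackets(out)
if err != nil {
return err
}
for _, pkt := range out[:n] {
handle(pkt.Payload)
}
}
}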

overlay/vhostnet/doc.go Normal file

@@ -0,0 +1,3 @@
// Package vhostnet implements methods to initialize vhost networking devices
// within the kernel-level virtio implementation and communicate with them.
package vhostnet

overlay/vhostnet/ioctl.go Normal file

@@ -0,0 +1,31 @@
package vhostnet
import (
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/vhost"
)
const (
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
// or TAP device.
//
// Request payload: [vhost.QueueFile]
// Kernel name: VHOST_NET_SET_BACKEND
vhostNetIoctlSetBackend = 0x4008af30
)
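// The value above matches the kernel's _IOW(VHOST_VIRTIO, 0x30,
// struct vhost_vring_file) encoding. As an illustration (this helper is not
// used anywhere), the constant can be derived from the ioctl bit layout:
// ioctlIOW(0xAF /* VHOST_VIRTIO */, 0x30, 8 /* sizeof(vhost_vring_file) */)
// evaluates to 0x4008af30.
func ioctlIOW(typ, nr, size uintptr) uintptr {
const iocWrite = 1 // _IOC_WRITE
return iocWrite<<30 | size<<16 | typ<<8 | nr
}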
// SetQueueBackend attaches a virtqueue of the vhost networking device
// described by controlFD to the given backend file descriptor.
// The backend file descriptor can either be a RAW socket or a TAP device. When
// it is -1, the queue will be detached.
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
QueueIndex: queueIndex,
FD: int32(backendFD),
})); err != nil {
return fmt.Errorf("set queue backend file descriptor: %w", err)
}
return nil
}
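// exampleDetachQueues is an illustrative teardown sketch: passing -1 as the
// backend detaches a queue again, for example before closing the control file
// descriptor.
func exampleDetachQueues(controlFD int) error {
for _, queueIndex := range []uint32{receiveQueueIndex, transmitQueueIndex} {
if err := SetQueueBackend(controlFD, queueIndex, -1); err != nil {
return err
}
}
return nil
}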


@@ -0,0 +1,69 @@
package vhostnet
import (
"errors"
"github.com/slackhq/nebula/overlay/virtqueue"
)
type optionValues struct {
queueSize int
backendFD int
}
func (o *optionValues) apply(options []Option) {
for _, option := range options {
option(o)
}
}
func (o *optionValues) validate() error {
if o.queueSize == -1 {
return errors.New("queue size is required")
}
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
return err
}
if o.backendFD == -1 {
return errors.New("backend file descriptor is required")
}
return nil
}
var optionDefaults = optionValues{
// Required.
queueSize: -1,
// Required.
backendFD: -1,
}
// Option can be passed to [NewDevice] to influence device creation.
type Option func(*optionValues)
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
// that are to be created for the device. It specifies the number of
// entries/buffers each queue can hold. This also affects the memory
// consumption.
// This is required and must be an integer from 1 to 32768 that is also a power
// of 2.
func WithQueueSize(queueSize int) Option {
return func(o *optionValues) { o.queueSize = queueSize }
}
// WithBackendFD returns an [Option] that sets the file descriptor of the
// backend that will be used for the queues of the device. The device will write
// and read packets to/from that backend. The file descriptor can belong to
// either a RAW socket or a TUN/TAP device.
// This option is required.
func WithBackendFD(backendFD int) Option {
return func(o *optionValues) { o.backendFD = backendFD }
}
//// WithBackendDevice returns an [Option] that sets the given TAP device as the
//// backend that will be used for the queues of the device. The device will
//// write and read packets to/from that backend. The TAP device should have been
//// created with the [tuntap.WithVirtioNetHdr] option enabled.
//// Either this or [WithBackendFD] is required.
//func WithBackendDevice(dev *tuntap.Device) Option {
// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
//}