package vhostnet import ( "context" "errors" "fmt" "os" "runtime" "github.com/slackhq/nebula/overlay/vhost" "github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/util/virtio" "golang.org/x/sys/unix" ) // ErrDeviceClosed is returned when the [Device] is closed while operations are // still running. var ErrDeviceClosed = errors.New("device was closed") // The indexes for the receive and transmit queues. const ( receiveQueueIndex = 0 transmitQueueIndex = 1 ) // Device represents a vhost networking device within the kernel-level virtio // implementation and provides methods to interact with it. type Device struct { initialized bool controlFD int ReceiveQueue *virtqueue.SplitQueue TransmitQueue *virtqueue.SplitQueue } // NewDevice initializes a new vhost networking device within the // kernel-level virtio implementation, sets up the virtqueues and returns a // [Device] instance that can be used to communicate with that vhost device. // // There are multiple options that can be passed to this constructor to // influence device creation: // - [WithQueueSize] // - [WithBackendFD] // - [WithBackendDevice] // // Remember to call [Device.Close] after use to free up resources. func NewDevice(options ...Option) (*Device, error) { var err error opts := optionDefaults opts.apply(options) if err = opts.validate(); err != nil { return nil, fmt.Errorf("invalid options: %w", err) } dev := Device{ controlFD: -1, } // Clean up a partially initialized device when something fails. defer func() { if err != nil { _ = dev.Close() } }() // Retrieve a new control file descriptor. This will be used to configure // the vhost networking device in the kernel. dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666) if err != nil { return nil, fmt.Errorf("get control file descriptor: %w", err) } if err = vhost.OwnControlFD(dev.controlFD); err != nil { return nil, fmt.Errorf("own control file descriptor: %w", err) } // Advertise the supported features. This isn't much for now. // TODO: Add feature options and implement proper feature negotiation. getFeatures, err := vhost.GetFeatures(dev.controlFD) if err != nil { return nil, fmt.Errorf("get features: %w", err) } if getFeatures == 0 { } //const funky = virtio.Feature(1 << 27) //features := virtio.FeatureVersion1 | funky // | todo virtio.FeatureNetMergeRXBuffers features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers if err = vhost.SetFeatures(dev.controlFD, features); err != nil { return nil, fmt.Errorf("set features: %w", err) } // Initialize and register the queues needed for the networking device. if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil { return nil, fmt.Errorf("create receive queue: %w", err) } if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil { return nil, fmt.Errorf("create transmit queue: %w", err) } // Set up memory mappings for all buffers used by the queues. This has to // happen before a backend for the queues can be registered. memoryLayout := vhost.NewMemoryLayoutForQueues( []*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue}, ) if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil { return nil, fmt.Errorf("setup memory layout: %w", err) } // Set the queue backends. This activates the queues within the kernel. if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil { return nil, fmt.Errorf("set receive queue backend: %w", err) } if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil { return nil, fmt.Errorf("set transmit queue backend: %w", err) } // Fully populate the receive queue with available buffers which the device // can write new packets into. if err = dev.refillReceiveQueue(); err != nil { return nil, fmt.Errorf("refill receive queue: %w", err) } dev.initialized = true // Make sure to clean up even when the device gets garbage collected without // Close being called first. devPtr := &dev runtime.SetFinalizer(devPtr, (*Device).Close) return devPtr, nil } // refillReceiveQueue offers as many new device-writable buffers to the device // as the queue can fit. The device will then use these to write received // packets. func (dev *Device) refillReceiveQueue() error { for { _, err := dev.ReceiveQueue.OfferInDescriptorChains(1) if err != nil { if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) { // Queue is full, job is done. return nil } return fmt.Errorf("offer descriptor chain: %w", err) } } } // Close cleans up the vhost networking device within the kernel and releases // all resources used for it. // The implementation will try to release as many resources as possible and // collect potential errors before returning them. func (dev *Device) Close() error { dev.initialized = false // Closing the control file descriptor will unregister all queues from the // kernel. if dev.controlFD >= 0 { if err := unix.Close(dev.controlFD); err != nil { // Return an error and do not continue, because the memory used for // the queues should not be released before they were unregistered // from the kernel. return fmt.Errorf("close control file descriptor: %w", err) } dev.controlFD = -1 } var errs []error if dev.ReceiveQueue != nil { if err := dev.ReceiveQueue.Close(); err == nil { dev.ReceiveQueue = nil } else { errs = append(errs, fmt.Errorf("close receive queue: %w", err)) } } if dev.TransmitQueue != nil { if err := dev.TransmitQueue.Close(); err == nil { dev.TransmitQueue = nil } else { errs = append(errs, fmt.Errorf("close transmit queue: %w", err)) } } if len(errs) == 0 { // Everything was cleaned up. No need to run the finalizer anymore. runtime.SetFinalizer(dev, nil) } return errors.Join(errs...) } // ensureInitialized is used as a guard to prevent methods to be called on an // uninitialized instance. func (dev *Device) ensureInitialized() { if !dev.initialized { panic("device is not initialized") } } // createQueue creates a new virtqueue and registers it with the vhost device // using the given index. func createQueue(controlFD int, queueIndex int, queueSize int) (*virtqueue.SplitQueue, error) { var ( queue *virtqueue.SplitQueue err error ) if queue, err = virtqueue.NewSplitQueue(queueSize); err != nil { return nil, fmt.Errorf("create virtqueue: %w", err) } if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil { return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err) } return queue, nil } // truncateBuffers returns a new list of buffers whose combined length matches // exactly the specified length. When the specified length exceeds the length of // the buffers, this is an error. When it is smaller, the buffer list will be // truncated accordingly. func truncateBuffers(buffers [][]byte, length int) (out [][]byte) { for _, buffer := range buffers { if length < len(buffer) { out = append(out, buffer[:length]) return } out = append(out, buffer) length -= len(buffer) } if length > 0 { panic("length exceeds the combined length of all buffers") } return } func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error { // Prepend the packet with its virtio-net header. vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY if err := vnethdr.Encode(vnethdrBuf); err != nil { return fmt.Errorf("encode vnethdr: %w", err) } vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86 vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets) if err != nil { return fmt.Errorf("offer descriptor chain: %w", err) } //todo surely there's something better to do here doneYet := map[uint16]bool{} for _, chain := range chainIndexes { doneYet[chain] = false } for { txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO()) if err != nil { return err } else if len(txedChains) == 0 { continue //todo will this ever exit? } for c := range txedChains { doneYet[txedChains[c].GetHead()] = true } done := true //optimism! for _, x := range doneYet { if !x { done = false break } } if done { break } } // Wait for the packet to have been transmitted. for i := range chainIndexes { if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil { return fmt.Errorf("free descriptor chain: %w", err) } } return nil } // TODO: Make above methods cancelable by taking a context.Context argument? // TODO: Implement zero-copy variants to transmit and receive packets? // processChains processes as many chains as needed to create one packet. The number of processed chains is returned. func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) { //read first element to see how many descriptors we need: pkt.Reset() n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload, int(chains[0].Length)) //todo if err != nil { return 0, err } // The specification requires that the first descriptor chain starts // with a virtio-net header. It is not clear, whether it is also // required to be fully contained in the first buffer of that // descriptor chain, but it is reasonable to assume that this is // always the case. // The decode method already does the buffer length check. if err = pkt.Header.Decode(pkt.Payload[0:]); err != nil { // The device misbehaved. There is no way we can gracefully // recover from this, because we don't know how many of the // following descriptor chains belong to this packet. return 0, fmt.Errorf("decode vnethdr: %w", err) } //we have the header now: what do we need to do? if int(pkt.Header.NumBuffers) > len(chains) { return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains)) } //shift the buffer out of out: pkt.Payload = pkt.Payload[virtio.NetHdrSize:] cursor := n - virtio.NetHdrSize if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 { pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize] return 1, nil } i := 1 // we used chain 0 already for i = 1; i < len(chains); i++ { n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length)) if err != nil { // When this fails we may miss to free some descriptor chains. We // could try to mitigate this by deferring the freeing somehow, but // it's not worth the hassle. When this method fails, the queue will // be in a broken state anyway. return i, fmt.Errorf("get descriptor chain: %w", err) } cursor += n } //todo this has to be wrong pkt.Payload = pkt.Payload[:cursor] return i, nil } func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) { //todo optimize? var chains []virtqueue.UsedElement var err error //if len(dev.extraRx) == 0 { chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), 64) //todo config batch if err != nil { return 0, err } if len(chains) == 0 { return 0, nil } //} else { // chains = dev.extraRx //} numPackets := 0 chainsIdx := 0 for numPackets = 0; chainsIdx < len(chains); numPackets++ { if numPackets >= len(out) { //dev.extraRx = chains[chainsIdx:] //return numPackets, nil return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets) } numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:]) if err != nil { return 0, err } chainsIdx += numChains } // Now that we have copied all buffers, we can recycle the used descriptor chains if err := dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil { return 0, err } //if we don't churn chains, maybe we don't need this? // It's advised to always keep the rx queue fully populated with available buffers which the device can write new packets into. //if err := dev.refillReceiveQueue(); err != nil { // return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err) //} return numPackets, nil }