mirror of https://github.com/slackhq/nebula.git (synced 2025-11-22 16:34:25 +01:00)

Commit: broken chkpt
@@ -2,19 +2,22 @@ package overlay
 import (
     "fmt"
    "io"
    "net"
    "net/netip"

    "github.com/sirupsen/logrus"
    "github.com/slackhq/nebula/config"
    "github.com/slackhq/nebula/overlay/virtqueue"
+   "github.com/slackhq/nebula/packet"
    "github.com/slackhq/nebula/util"
 )

 const DefaultMTU = 1300

 type TunDev interface {
-   ReadMany([][]byte) (int, error)
    io.WriteCloser
+   ReadMany([]*packet.VirtIOPacket) (int, error)
    WriteMany([][]byte) (int, error)
    GetQueues() []*virtqueue.SplitQueue
 }
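As a sketch of how the new batch-oriented interface might be consumed (the VirtIOPacket shape is assumed from this diff: a struct carrying a reusable Payload byte slice; the batch size, buffer size, and read loop are illustrative, not part of the commit):

package overlayexample

// VirtIOPacket is a stand-in for packet.VirtIOPacket as it appears in this
// diff: a reusable payload buffer plus whatever header metadata the real
// type carries.
type VirtIOPacket struct {
    Payload []byte
}

// TunDev mirrors the ReadMany portion of the interface above.
type TunDev interface {
    ReadMany([]*VirtIOPacket) (int, error)
}

// readLoop preallocates one batch of packets and reuses it across calls,
// which appears to be the point of the []*VirtIOPacket signature: no
// per-read allocations.
func readLoop(dev TunDev, handle func([]byte)) error {
    const batchSize = 64 // illustrative
    batch := make([]*VirtIOPacket, batchSize)
    for i := range batch {
        batch[i] = &VirtIOPacket{Payload: make([]byte, 1300)} // DefaultMTU-sized, illustrative
    }
    for {
        n, err := dev.ReadMany(batch)
        if err != nil {
            return err
        }
        for i := 0; i < n; i++ {
            handle(batch[i].Payload)
            // Restore full capacity before the packet is reused; ReceivePackets
            // in this diff truncates Payload to the received length.
            batch[i].Payload = batch[i].Payload[:cap(batch[i].Payload)]
        }
    }
}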
@@ -10,6 +10,7 @@ import (
    "github.com/sirupsen/logrus"
    "github.com/slackhq/nebula/iputil"
    "github.com/slackhq/nebula/overlay/virtqueue"
+   "github.com/slackhq/nebula/packet"
    "github.com/slackhq/nebula/routing"
 )

@@ -122,8 +123,8 @@ func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
    return out, nil
 }

-func (t *disabledTun) ReadMany(b [][]byte) (int, error) {
-   return t.Read(b[0])
+func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket) (int, error) {
+   return t.Read(b[0].Payload)
 }

 func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
@@ -18,10 +18,11 @@ import (
    "github.com/sirupsen/logrus"
    "github.com/slackhq/nebula/config"
    "github.com/slackhq/nebula/overlay/vhostnet"
    "github.com/slackhq/nebula/overlay/virtio"
    "github.com/slackhq/nebula/overlay/virtqueue"
    "github.com/slackhq/nebula/packet"
    "github.com/slackhq/nebula/routing"
    "github.com/slackhq/nebula/util"
    "github.com/slackhq/nebula/util/virtio"
    "github.com/vishvananda/netlink"
    "golang.org/x/sys/unix"
)
@@ -713,16 +714,11 @@ func (t *tun) Close() error {
    return nil
 }

-func (t *tun) ReadMany(p [][]byte) (int, error) {
-   //todo call consumeUsedRing here instead of its own thread
-
-   n, hdr, err := t.vdev.ReceivePacket(p) //we are TXing
+func (t *tun) ReadMany(p []*packet.VirtIOPacket) (int, error) {
+   n, err := t.vdev.ReceivePackets(p) //we are TXing
    if err != nil {
        return 0, err
    }
-   if hdr.NumBuffers > 1 {
-       t.l.WithField("num_buffers", hdr.NumBuffers).Info("wow, lots to TX from tun")
-   }
    return n, nil
 }

@@ -739,7 +735,7 @@ func (t *tun) Write(b []byte) (int, error) {
        NumBuffers: 0,
    }

-   err := t.vdev.TransmitPacket(hdr, b)
+   err := t.vdev.TransmitPackets(hdr, [][]byte{b})
    if err != nil {
        t.l.WithError(err).Error("Transmitting packet")
        return 0, err
@@ -7,6 +7,7 @@ import (
    "github.com/sirupsen/logrus"
    "github.com/slackhq/nebula/config"
    "github.com/slackhq/nebula/overlay/virtqueue"
+   "github.com/slackhq/nebula/packet"
    "github.com/slackhq/nebula/routing"
 )

@@ -67,8 +68,8 @@ func (d *UserDevice) Close() error {
    return nil
 }

-func (d *UserDevice) ReadMany(b [][]byte) (int, error) {
-   return d.Read(b[0])
+func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket) (int, error) {
+   return d.Read(b[0].Payload)
 }

 func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
@@ -4,8 +4,8 @@ import (
    "fmt"
    "unsafe"

    "github.com/slackhq/nebula/overlay/virtio"
    "github.com/slackhq/nebula/overlay/virtqueue"
    "github.com/slackhq/nebula/util/virtio"
    "golang.org/x/sys/unix"
)
@@ -1,14 +1,16 @@
package vhostnet

import (
    "context"
    "errors"
    "fmt"
    "os"
    "runtime"

    "github.com/slackhq/nebula/overlay/vhost"
    "github.com/slackhq/nebula/overlay/virtio"
    "github.com/slackhq/nebula/overlay/virtqueue"
    "github.com/slackhq/nebula/packet"
    "github.com/slackhq/nebula/util/virtio"
    "golang.org/x/sys/unix"
)

@@ -31,12 +33,7 @@ type Device struct {
    ReceiveQueue  *virtqueue.SplitQueue
    TransmitQueue *virtqueue.SplitQueue

-   // transmitted contains channels for each possible descriptor chain head
-   // index. This is used for packet transmit notifications.
-   // When a packet was transmitted and the descriptor chain was used by the
-   // device, the corresponding channel receives the [virtqueue.UsedElement]
-   // instance provided by the device.
-   transmitted []chan virtqueue.UsedElement
+   extraRx []virtqueue.UsedElement
 }

 // NewDevice initializes a new vhost networking device within the
@@ -126,25 +123,6 @@ func NewDevice(options ...Option) (*Device, error) {
        return nil, fmt.Errorf("refill receive queue: %w", err)
    }

-   // Initialize channels for transmit notifications.
-   dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
-   for i := range len(dev.transmitted) {
-       // It is important to use a single-element buffered channel here.
-       // When the channel was unbuffered and the monitorTransmitQueue
-       // goroutine would write into it, the writing would block which could
-       // lead to deadlocks in case transmit notifications do not arrive in
-       // order.
-       // When the goroutine would use fire-and-forget to write into that
-       // channel, there may be a chance that the TransmitPacket does not
-       // receive the transmit notification due to this being a race condition.
-       // Buffering a single transmit notification resolves this without race
-       // conditions or possible deadlocks.
-       dev.transmitted[i] = make(chan virtqueue.UsedElement, 1)
-   }
-
-   // Monitor transmit queue in background.
-   go dev.monitorTransmitQueue()
-
    dev.initialized = true

    // Make sure to clean up even when the device gets garbage collected without

@@ -155,32 +133,12 @@ func NewDevice(options ...Option) (*Device, error) {
    return devPtr, nil
 }
-// monitorTransmitQueue waits for the device to advertise used descriptor chains
-// in the transmit queue and produces a transmit notification via the
-// corresponding channel.
-func (dev *Device) monitorTransmitQueue() {
-   usedChan := dev.TransmitQueue.UsedDescriptorChains()
-   for {
-       used, ok := <-usedChan
-       if !ok {
-           // The queue was closed.
-           return
-       }
-       if int(used.DescriptorIndex) > len(dev.transmitted) {
-           panic(fmt.Sprintf("device provided a used descriptor index (%d) that is out of range",
-               used.DescriptorIndex))
-       }
-
-       dev.transmitted[used.DescriptorIndex] <- used
-   }
-}
-
 // refillReceiveQueue offers as many new device-writable buffers to the device
 // as the queue can fit. The device will then use these to write received
 // packets.
 func (dev *Device) refillReceiveQueue() error {
    for {
-       _, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false)
+       _, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
        if err != nil {
            if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
                // Queue is full, job is done.

@@ -279,38 +237,6 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
    return
 }
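The refill loop relies on offering device-writable chains one at a time until the queue reports exhaustion. A self-contained sketch of that keep-the-ring-full pattern (toyQueue is a stand-in, not the real virtqueue.SplitQueue):

package vqexample

import "errors"

var errNoFreeDescriptors = errors.New("not enough free descriptors")

// toyQueue stands in for virtqueue.SplitQueue.
type toyQueue struct{ free int }

// offerIn mimics OfferInDescriptorChains(1): it consumes one free
// descriptor or reports exhaustion.
func (q *toyQueue) offerIn() (uint16, error) {
    if q.free == 0 {
        return 0, errNoFreeDescriptors
    }
    q.free--
    return uint16(q.free), nil
}

// refill keeps offering buffers until the ring is full, mirroring
// refillReceiveQueue in the diff.
func refill(q *toyQueue) error {
    for {
        if _, err := q.offerIn(); err != nil {
            if errors.Is(err, errNoFreeDescriptors) {
                return nil // ring fully populated, job is done
            }
            return err
        }
    }
}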
-// TransmitPacket writes the given packet into the transmit queue of this
-// device. The packet will be prepended with the [virtio.NetHdr].
-//
-// When the queue is full, this will block until the queue has enough room to
-// transmit the packet. This method will not return before the packet was
-// transmitted and the device notifies that it has used the packet buffer.
-func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
-   // Prepend the packet with its virtio-net header.
-   vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
-   if err := vnethdr.Encode(vnethdrBuf); err != nil {
-       return fmt.Errorf("encode vnethdr: %w", err)
-   }
-   vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
-   vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
-   outBuffers := [][]byte{vnethdrBuf, packet}
-   //outBuffers := [][]byte{packet}
-
-   chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
-   if err != nil {
-       return fmt.Errorf("offer descriptor chain: %w", err)
-   }
-
-   // Wait for the packet to have been transmitted.
-   <-dev.transmitted[chainIndex]
-
-   if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
-       return fmt.Errorf("free descriptor chain: %w", err)
-   }
-
-   return nil
-}
 func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
    // Prepend the packet with its virtio-net header.
    vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
@@ -320,15 +246,42 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
    vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
    vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype

-   chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
+   chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
    if err != nil {
        return fmt.Errorf("offer descriptor chain: %w", err)
    }
+   //todo surely there's something better to do here
+   doneYet := map[uint16]bool{}
+   for _, chain := range chainIndexes {
+       doneYet[chain] = false
+   }
+
+   for {
+       txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO())
+       if err != nil {
+           return err
+       } else if len(txedChains) == 0 {
+           continue //todo will this ever exit?
+       }
+       for c := range txedChains {
+           doneYet[txedChains[c].GetHead()] = true
+       }
+       done := true //optimism!
+       for _, x := range doneYet {
+           if !x {
+               done = false
+               break
+           }
+       }
+
+       if done {
+           break
+       }
+   }

+   //todo blocking here suxxxx
    // Wait for the packet to have been transmitted.
    for i := range chainIndexes {
-       <-dev.transmitted[chainIndexes[i]]
-
        if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
            return fmt.Errorf("free descriptor chain: %w", err)

@@ -338,106 +291,104 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
    return nil
 }
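The NetHdrSize+14 buffer built above appears to be the 12-byte virtio-net header followed by a synthetic 14-byte Ethernet header whose EtherType is forced to 0x86DD (IPv6), presumably because the vhost-net/TAP side expects Ethernet framing; the //todo WHY suggests the author had not pinned this down either. A sketch of that assumed layout (helper name is mine, not from the codebase):

package vhostexample

import "encoding/binary"

const (
    netHdrSize   = 12 // virtio-net header size used in the diff
    etherHdrSize = 14 // dst MAC (6) + src MAC (6) + EtherType (2)
)

// buildPrefix reproduces the hand-rolled layout from the diff:
// [virtio-net header][zeroed MACs][EtherType 0x86DD].
func buildPrefix(vnethdr []byte) []byte {
    buf := make([]byte, netHdrSize+etherHdrSize)
    copy(buf[:netHdrSize], vnethdr)
    // The last two bytes of the Ethernet header carry the EtherType;
    // 0x86DD marks the payload as IPv6, matching the 0x86/0xdd writes above.
    binary.BigEndian.PutUint16(buf[netHdrSize+etherHdrSize-2:], 0x86DD)
    return buf
}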
// ReceivePacket reads the next available packet from the receive queue of this
// device and returns its [virtio.NetHdr] and packet data separately.
//
// When no packet is available, this will block until there is one.
//
// When this method returns an error, the receive queue will likely be in a
// broken state which this implementation cannot recover from. The caller should
// close the device and not attempt any additional receives.
func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
    var (
        chainHeads []uint16

// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?

        vnethdr virtio.NetHdr
        buffers [][]byte

// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
    //read first element to see how many descriptors we need:
    pkt.Payload = pkt.Payload[:cap(pkt.Payload)]
    n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload)
    if err != nil {
        return 0, err
    }
    // The specification requires that the first descriptor chain starts
    // with a virtio-net header. It is not clear whether it is also
    // required to be fully contained in the first buffer of that
    // descriptor chain, but it is reasonable to assume that this is
    // always the case.
    // The decode method already does the buffer length check.
    if err = pkt.Header.Decode(pkt.Payload[0:]); err != nil {
        // The device misbehaved. There is no way we can gracefully
        // recover from this, because we don't know how many of the
        // following descriptor chains belong to this packet.
        return 0, fmt.Errorf("decode vnethdr: %w", err)
    }

        // Each packet starts with a virtio-net header which we have to subtract
        // from the total length.
        packetLength = -virtio.NetHdrSize
    )
    //we have the header now: what do we need to do?
    if int(pkt.Header.NumBuffers) > len(chains) {
        return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
    }

    lenRead := 0
    //shift the buffer out of out:
    copy(pkt.Payload, pkt.Payload[virtio.NetHdrSize:])

    // We presented FeatureNetMergeRXBuffers to the device, so one packet may be
    // made of multiple descriptor chains which are to be merged.
    for remainingChains := 1; remainingChains > 0; remainingChains-- {
        // Get the next descriptor chain.
        usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
        if !ok {
            return 0, virtio.NetHdr{}, ErrDeviceClosed
        }
    cursor := n - virtio.NetHdrSize

        // Track this chain to be freed later.
        head := uint16(usedElement.DescriptorIndex)
        chainHeads = append(chainHeads, head)
    if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
        pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
        return 1, nil
    }

        n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:])
    i := 1
    // we used chain 0 already
    for i = 1; i < len(chains); i++ {
        n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:])
        if err != nil {
            // When this fails we may fail to free some descriptor chains. We
            // could try to mitigate this by deferring the freeing somehow, but
            // it's not worth the hassle. When this method fails, the queue will
            // be in a broken state anyway.
            return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err)
            return i, fmt.Errorf("get descriptor chain: %w", err)
        }
        lenRead += n
        packetLength += int(usedElement.Length)
        cursor += n
    }
    //todo this has to be wrong
    pkt.Payload = pkt.Payload[:cursor]
    return i, nil
}

        // Is this the first descriptor chain we process?
        if len(buffers) == 0 {
            // The specification requires that the first descriptor chain starts
            // with a virtio-net header. It is not clear whether it is also
            // required to be fully contained in the first buffer of that
            // descriptor chain, but it is reasonable to assume that this is
            // always the case.
            // The decode method already does the buffer length check.
            if err = vnethdr.Decode(out[0:]); err != nil {
                // The device misbehaved. There is no way we can gracefully
                // recover from this, because we don't know how many of the
                // following descriptor chains belong to this packet.
                return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err)
            }
            lenRead = 0
            out = out[virtio.NetHdrSize:]
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
    //todo optimize?
    var chains []virtqueue.UsedElement
    var err error
    //if len(dev.extraRx) == 0 {
    chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), 64) //todo config batch
    if err != nil {
        return 0, err
    }
    if len(chains) == 0 {
        return 0, nil
    }
    //} else {
    //  chains = dev.extraRx
    //}

            // The virtio-net header tells us how many descriptor chains this
            // packet is long.
            remainingChains = int(vnethdr.NumBuffers)
    numPackets := 0
    chainsIdx := 0
    for numPackets = 0; chainsIdx < len(chains); numPackets++ {
        if numPackets >= len(out) {
            //dev.extraRx = chains[chainsIdx:]
            //return numPackets, nil
            return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
        }

        //buffers = append(buffers, inBuffers...)
        numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
        if err != nil {
            return 0, err
        }
        chainsIdx += numChains
    }

    // Copy all the buffers together to produce the complete packet slice.
    //out = out[:packetLength]
    //copied := 0
    //for _, buffer := range buffers {
    //  copied += copy(out[copied:], buffer)
    //}
    //if copied != packetLength {
    //  panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
    //}

    // Now that we have copied all buffers, we can free the used descriptor
    // chains again.
    // TODO: Recycling the descriptor chains would be more efficient than
    // freeing them just to offer them again right after.
    for _, head := range chainHeads {
        if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
            return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
        }
    // Now that we have copied all buffers, we can recycle the used descriptor chains
    if err := dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
        return 0, err
    }

    //if we don't churn chains, maybe we don't need this?
    // It's advised to always keep the receive queue fully populated with
    // available buffers which the device can write new packets into.
    // It's advised to always keep the rx queue fully populated with available buffers which the device can write new packets into.
    //if err := dev.refillReceiveQueue(); err != nil {
    //  return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
    //}

    return packetLength, vnethdr, nil
    return numPackets, nil
}

// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?
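To make the merge-receive logic concrete: with FeatureNetMergeRXBuffers negotiated, the device may spread one packet over NumBuffers descriptor chains, and processChains stitches them back together with a cursor. A toy version of that reassembly, detached from the queue plumbing (all names here are local assumptions):

package mrgexample

// mergeRX concatenates numBuffers chains into a single payload and
// returns the unconsumed remainder, mirroring how ReceivePackets
// advances chainsIdx by the count that processChains reports.
func mergeRX(chains [][]byte, numBuffers int) (pkt []byte, rest [][]byte) {
    n := min(numBuffers, len(chains))
    for _, c := range chains[:n] {
        pkt = append(pkt, c...)
    }
    return pkt, chains[n:]
}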
@@ -1,3 +0,0 @@
-// Package virtio contains some generic types and concepts related to the virtio
-// protocol.
-package virtio
@@ -1,136 +0,0 @@
-package virtio
-
-// Feature contains feature bits that describe a virtio device or driver.
-type Feature uint64
-
-// Device-independent feature bits.
-//
-// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-6600006
-const (
-   // FeatureIndirectDescriptors indicates that the driver can use descriptors
-   // with an additional layer of indirection.
-   FeatureIndirectDescriptors Feature = 1 << 28
-
-   // FeatureVersion1 indicates compliance with version 1.0 of the virtio
-   // specification.
-   FeatureVersion1 Feature = 1 << 32
-)
-
-// Feature bits for networking devices.
-//
-// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-2200003
-const (
-   // FeatureNetDeviceCsum indicates that the device can handle packets with
-   // partial checksum (checksum offload).
-   FeatureNetDeviceCsum Feature = 1 << 0
-
-   // FeatureNetDriverCsum indicates that the driver can handle packets with
-   // partial checksum.
-   FeatureNetDriverCsum Feature = 1 << 1
-
-   // FeatureNetCtrlDriverOffloads indicates support for dynamic offload state
-   // reconfiguration.
-   FeatureNetCtrlDriverOffloads Feature = 1 << 2
-
-   // FeatureNetMTU indicates that the device reports a maximum MTU value.
-   FeatureNetMTU Feature = 1 << 3
-
-   // FeatureNetMAC indicates that the device provides a MAC address.
-   FeatureNetMAC Feature = 1 << 5
-
-   // FeatureNetDriverTSO4 indicates that the driver supports the TCP
-   // segmentation offload for received IPv4 packets.
-   FeatureNetDriverTSO4 Feature = 1 << 7
-
-   // FeatureNetDriverTSO6 indicates that the driver supports the TCP
-   // segmentation offload for received IPv6 packets.
-   FeatureNetDriverTSO6 Feature = 1 << 8
-
-   // FeatureNetDriverECN indicates that the driver supports the TCP
-   // segmentation offload with ECN for received packets.
-   FeatureNetDriverECN Feature = 1 << 9
-
-   // FeatureNetDriverUFO indicates that the driver supports the UDP
-   // fragmentation offload for received packets.
-   FeatureNetDriverUFO Feature = 1 << 10
-
-   // FeatureNetDeviceTSO4 indicates that the device supports the TCP
-   // segmentation offload for received IPv4 packets.
-   FeatureNetDeviceTSO4 Feature = 1 << 11
-
-   // FeatureNetDeviceTSO6 indicates that the device supports the TCP
-   // segmentation offload for received IPv6 packets.
-   FeatureNetDeviceTSO6 Feature = 1 << 12
-
-   // FeatureNetDeviceECN indicates that the device supports the TCP
-   // segmentation offload with ECN for received packets.
-   FeatureNetDeviceECN Feature = 1 << 13
-
-   // FeatureNetDeviceUFO indicates that the device supports the UDP
-   // fragmentation offload for received packets.
-   FeatureNetDeviceUFO Feature = 1 << 14
-
-   // FeatureNetMergeRXBuffers indicates that the driver can handle merged
-   // receive buffers.
-   // When this feature is negotiated, devices may merge multiple descriptor
-   // chains together to transport large received packets. [NetHdr.NumBuffers]
-   // will then contain the number of merged descriptor chains.
-   FeatureNetMergeRXBuffers Feature = 1 << 15
-
-   // FeatureNetStatus indicates that the device configuration status field is
-   // available.
-   FeatureNetStatus Feature = 1 << 16
-
-   // FeatureNetCtrlVQ indicates that a control channel virtqueue is
-   // available.
-   FeatureNetCtrlVQ Feature = 1 << 17
-
-   // FeatureNetCtrlRX indicates support for RX mode control (e.g. promiscuous
-   // or all-multicast) for packet receive filtering.
-   FeatureNetCtrlRX Feature = 1 << 18
-
-   // FeatureNetCtrlVLAN indicates support for VLAN filtering through the
-   // control channel.
-   FeatureNetCtrlVLAN Feature = 1 << 19
-
-   // FeatureNetDriverAnnounce indicates that the driver can send gratuitous
-   // packets.
-   FeatureNetDriverAnnounce Feature = 1 << 21
-
-   // FeatureNetMQ indicates that the device supports multiqueue with automatic
-   // receive steering.
-   FeatureNetMQ Feature = 1 << 22
-
-   // FeatureNetCtrlMACAddr indicates that the MAC address can be set through
-   // the control channel.
-   FeatureNetCtrlMACAddr Feature = 1 << 23
-
-   // FeatureNetDeviceUSO indicates that the device supports the UDP
-   // segmentation offload for received packets.
-   FeatureNetDeviceUSO Feature = 1 << 56
-
-   // FeatureNetHashReport indicates that the device can report a per-packet
-   // hash value and type.
-   FeatureNetHashReport Feature = 1 << 57
-
-   // FeatureNetDriverHdrLen indicates that the driver can provide the exact
-   // header length value (see [NetHdr.HdrLen]).
-   // Devices may benefit from knowing the exact header length.
-   FeatureNetDriverHdrLen Feature = 1 << 59
-
-   // FeatureNetRSS indicates that the device supports RSS (receive-side
-   // scaling) with configurable hash parameters.
-   FeatureNetRSS Feature = 1 << 60
-
-   // FeatureNetRSCExt indicates that the device can process duplicated ACKs
-   // and report the number of coalesced segments and duplicated ACKs.
-   FeatureNetRSCExt Feature = 1 << 61
-
-   // FeatureNetStandby indicates that the device may act as a standby for a
-   // primary device with the same MAC address.
-   FeatureNetStandby Feature = 1 << 62
-
-   // FeatureNetSpeedDuplex indicates that the device can report link speed and
-   // duplex mode.
-   FeatureNetSpeedDuplex Feature = 1 << 63
-)
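For reference, these constants are plain bit flags, and virtio feature negotiation is just a bitwise intersection of the device's and driver's sets; a minimal sketch (the negotiate/has helpers are mine, not from the codebase):

package featexample

type Feature uint64

const (
    FeatureNetMergeRXBuffers Feature = 1 << 15
    FeatureVersion1          Feature = 1 << 32
)

// negotiate keeps only the bits both sides advertise.
func negotiate(device, driver Feature) Feature { return device & driver }

// has reports whether a negotiated set contains f.
func has(set, f Feature) bool { return set&f != 0 }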
@@ -1,77 +0,0 @@
-package virtio
-
-import (
-   "errors"
-   "unsafe"
-
-   "golang.org/x/sys/unix"
-)
-
-// Workaround to make Go doc links work.
-var _ unix.Errno
-
-// NetHdrSize is the number of bytes needed to store a [NetHdr] in memory.
-const NetHdrSize = 12
-
-// ErrNetHdrBufferTooSmall is returned when a buffer is too small to fit a
-// virtio_net_hdr.
-var ErrNetHdrBufferTooSmall = errors.New("the buffer is too small to fit a virtio_net_hdr")
-
-// NetHdr defines the virtio_net_hdr as described by the virtio specification.
-type NetHdr struct {
-   // Flags that describe the packet.
-   // Possible values are:
-   //   - [unix.VIRTIO_NET_HDR_F_NEEDS_CSUM]
-   //   - [unix.VIRTIO_NET_HDR_F_DATA_VALID]
-   //   - [unix.VIRTIO_NET_HDR_F_RSC_INFO]
-   Flags uint8
-   // GSOType contains the type of segmentation offload that should be used for
-   // the packet.
-   // Possible values are:
-   //   - [unix.VIRTIO_NET_HDR_GSO_NONE]
-   //   - [unix.VIRTIO_NET_HDR_GSO_TCPV4]
-   //   - [unix.VIRTIO_NET_HDR_GSO_UDP]
-   //   - [unix.VIRTIO_NET_HDR_GSO_TCPV6]
-   //   - [unix.VIRTIO_NET_HDR_GSO_UDP_L4]
-   //   - [unix.VIRTIO_NET_HDR_GSO_ECN]
-   GSOType uint8
-   // HdrLen contains the length of the headers that need to be replicated by
-   // segmentation offloads. It's the number of bytes from the beginning of the
-   // packet to the beginning of the transport payload.
-   // Only used when [FeatureNetDriverHdrLen] is negotiated.
-   HdrLen uint16
-   // GSOSize contains the maximum size of each segmented packet beyond the
-   // header (payload size). In case of TCP, this is the MSS.
-   GSOSize uint16
-   // CsumStart contains the offset within the packet from which on the
-   // checksum should be computed.
-   CsumStart uint16
-   // CsumOffset specifies how many bytes after [NetHdr.CsumStart] the computed
-   // 16-bit checksum should be inserted.
-   CsumOffset uint16
-   // NumBuffers contains the number of merged descriptor chains when
-   // [FeatureNetMergeRXBuffers] is negotiated.
-   // This field is only used for packets received by the driver and should be
-   // zero for transmitted packets.
-   NumBuffers uint16
-}
-
-// Decode decodes the [NetHdr] from the given byte slice. The slice must contain
-// at least [NetHdrSize] bytes.
-func (v *NetHdr) Decode(data []byte) error {
-   if len(data) < NetHdrSize {
-       return ErrNetHdrBufferTooSmall
-   }
-   copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize), data[:NetHdrSize])
-   return nil
-}
-
-// Encode encodes the [NetHdr] into the given byte slice. The slice must have
-// room for at least [NetHdrSize] bytes.
-func (v *NetHdr) Encode(data []byte) error {
-   if len(data) < NetHdrSize {
-       return ErrNetHdrBufferTooSmall
-   }
-   copy(data[:NetHdrSize], unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize))
-   return nil
-}
@@ -1,43 +0,0 @@
-package virtio
-
-import (
-   "testing"
-   "unsafe"
-
-   "github.com/stretchr/testify/assert"
-   "github.com/stretchr/testify/require"
-   "golang.org/x/sys/unix"
-)
-
-func TestNetHdr_Size(t *testing.T) {
-   assert.EqualValues(t, NetHdrSize, unsafe.Sizeof(NetHdr{}))
-}
-
-func TestNetHdr_Encoding(t *testing.T) {
-   vnethdr := NetHdr{
-       Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
-       GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
-       HdrLen:     42,
-       GSOSize:    1472,
-       CsumStart:  34,
-       CsumOffset: 6,
-       NumBuffers: 16,
-   }
-
-   buf := make([]byte, NetHdrSize)
-   require.NoError(t, vnethdr.Encode(buf))
-
-   assert.Equal(t, []byte{
-       0x01, 0x05,
-       0x2a, 0x00,
-       0xc0, 0x05,
-       0x22, 0x00,
-       0x06, 0x00,
-       0x10, 0x00,
-   }, buf)
-
-   var decoded NetHdr
-   require.NoError(t, decoded.Decode(buf))
-
-   assert.Equal(t, vnethdr, decoded)
-}
@@ -82,22 +82,61 @@ func (r *AvailableRing) Address() uintptr {
 // offer adds the given descriptor chain heads to the available ring and
 // advances the ring index accordingly to make the device process the new
 // descriptor chains.
-func (r *AvailableRing) offer(chainHeads []uint16) {
+func (r *AvailableRing) offerElements(chains []UsedElement) {
    //always called under lock
    //r.mu.Lock()
    //defer r.mu.Unlock()

    // Add descriptor chain heads to the ring.
-   for offset, head := range chainHeads {
+   for offset, x := range chains {
        // The 16-bit ring index may overflow. This is expected and is not an
        // issue because the size of the ring array (which equals the queue
        // size) is always a power of 2 and smaller than the highest possible
        // 16-bit value.
        insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
-       r.ring[insertIndex] = head
+       r.ring[insertIndex] = x.GetHead()
    }

    // Increase the ring index by the number of descriptor chains added to the
    // ring.
-   *r.ringIndex += uint16(len(chainHeads))
+   *r.ringIndex += uint16(len(chains))
 }
+
+func (r *AvailableRing) offer(chains []uint16) {
+   //always called under lock
+   //r.mu.Lock()
+   //defer r.mu.Unlock()
+
+   // Add descriptor chain heads to the ring.
+   for offset, x := range chains {
+       // The 16-bit ring index may overflow. This is expected and is not an
+       // issue because the size of the ring array (which equals the queue
+       // size) is always a power of 2 and smaller than the highest possible
+       // 16-bit value.
+       insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+       r.ring[insertIndex] = x
+   }
+
+   // Increase the ring index by the number of descriptor chains added to the
+   // ring.
+   *r.ringIndex += uint16(len(chains))
+}
+
+func (r *AvailableRing) offerSingle(x uint16) {
+   //always called under lock
+   //r.mu.Lock()
+   //defer r.mu.Unlock()
+
+   offset := 0
+   // Add descriptor chain heads to the ring.
+
+   // The 16-bit ring index may overflow. This is expected and is not an
+   // issue because the size of the ring array (which equals the queue
+   // size) is always a power of 2 and smaller than the highest possible
+   // 16-bit value.
+   insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+   r.ring[insertIndex] = x
+
+   // Increase the ring index by the number of descriptor chains added to the ring.
+   *r.ringIndex += 1
+}
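The overflow comment repeated in these methods is doing real work: because the ring length is a power of two and 1<<16 is a multiple of any such length, the slot computed from the free-running uint16 index stays contiguous across the wrap. A quick demonstration:

package ringexample

import "fmt"

// slots shows that int(idx) % size advances contiguously through the
// uint16 wraparound when size is a power of two.
func slots() {
    const size = 256 // power of two, well below 1<<16
    idx := uint16(0xFFFE)
    for i := 0; i < 4; i++ {
        fmt.Printf("index=%#04x slot=%d\n", idx, int(idx)%size)
        idx++ // 0xFFFF wraps to 0x0000
    }
    // Prints slots 254, 255, 0, 1: no gap at the wrap.
}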
@@ -30,9 +30,9 @@ type SplitQueue struct {
    // chains and put them in the used ring.
    callEventFD eventfd.EventFD

-   // usedChains is a channel that receives [UsedElement]s for descriptor chains
+   // UsedChains is a channel that receives [UsedElement]s for descriptor chains
    // that were used by the device.
-   usedChains chan UsedElement
+   UsedChains chan UsedElement

    // moreFreeDescriptors is a channel that signals when any descriptors were
    // put back into the free chain of the descriptor table. This is used to

@@ -51,6 +51,7 @@ type SplitQueue struct {
    itemSize int

    epoll eventfd.Epoll
+   more  int
 }

 // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size

@@ -131,7 +132,7 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
    }

    // Initialize channels.
-   sq.usedChains = make(chan UsedElement, queueSize)
+   sq.UsedChains = make(chan UsedElement, queueSize)
    sq.moreFreeDescriptors = make(chan struct{})

    sq.epoll, err = eventfd.NewEpoll()

@@ -190,20 +191,6 @@ func (sq *SplitQueue) CallEventFD() int {
    return sq.callEventFD.FD()
 }

-// UsedDescriptorChains returns the channel that receives [UsedElement]s for all
-// descriptor chains that were used by the device.
-//
-// Users of the [SplitQueue] should read from this channel, handle the used
-// descriptor chains and free them using [SplitQueue.FreeDescriptorChain] when
-// they're done with them. When this does not happen, the queue will run full
-// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
-//
-// When [SplitQueue.Close] is called, this channel will be closed as well.
-func (sq *SplitQueue) UsedDescriptorChains() chan UsedElement {
-   sq.ensureInitialized()
-   return sq.usedChains
-}
-
 // startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
 // A function is returned that can be used to gracefully cancel it. todo rename
 func (sq *SplitQueue) startConsumeUsedRing() func() error {
@@ -225,14 +212,49 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro
    var n int
    var err error
    for ctx.Err() == nil {

        // Wait for a signal from the device.
        if n, err = sq.epoll.Block(); err != nil {
            return nil, fmt.Errorf("wait: %w", err)
        }

        if n > 0 {
-           out := sq.usedRing.take()
-           _ = sq.epoll.Clear() //???
+           stillNeedToTake, out := sq.usedRing.take(-1)
+           sq.more = stillNeedToTake
+           if stillNeedToTake == 0 {
+               _ = sq.epoll.Clear() //???
+           }
            return out, nil
        }
    }
    return nil, ctx.Err()
 }

+func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
+   var n int
+   var err error
+   for ctx.Err() == nil {
+
+       //we have leftovers in the fridge
+       if sq.more > 0 {
+           stillNeedToTake, out := sq.usedRing.take(maxToTake)
+           sq.more = stillNeedToTake
+           if stillNeedToTake == 0 {
+               _ = sq.epoll.Clear() //???
+           }
+
+           return out, nil
+       }
+
+       // Wait for a signal from the device.
+       if n, err = sq.epoll.Block(); err != nil {
+           return nil, fmt.Errorf("wait: %w", err)
+       }
+       if n > 0 {
+           stillNeedToTake, out := sq.usedRing.take(maxToTake)
+           sq.more = stillNeedToTake
+           if stillNeedToTake == 0 {
+               _ = sq.epoll.Clear() //???
+           }
+           return out, nil
+       }
+   }

@@ -240,12 +262,6 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro
    return nil, ctx.Err()
 }
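BlockAndGetHeadsCapped parks whatever it could not hand out in sq.more and only clears the eventfd once everything has been drained, so callers are expected to keep calling until the leftovers are gone. A sketch of that contract from the consumer side (headSource is a toy stand-in, not the real SplitQueue):

package capexample

import "context"

// headSource models the capped-take contract: at most maxToTake elements
// per call, with leftovers retained, like sq.more in the diff.
type headSource struct{ pending []uint16 }

func (s *headSource) takeCapped(ctx context.Context, maxToTake int) ([]uint16, error) {
    if err := ctx.Err(); err != nil {
        return nil, err
    }
    n := min(maxToTake, len(s.pending))
    out := s.pending[:n]
    s.pending = s.pending[n:]
    return out, nil
}

// drain pulls bounded batches until the source is empty, so memory use per
// iteration stays capped regardless of how much the device completed.
func drain(ctx context.Context, s *headSource, batch int) ([]uint16, error) {
    var all []uint16
    for len(s.pending) > 0 {
        heads, err := s.takeCapped(ctx, batch)
        if err != nil {
            return all, err
        }
        all = append(all, heads...)
    }
    return all, nil
}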
-// blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up.
-// it is its own function so maybe it might show up in pprof
-func (sq *SplitQueue) blockForMoreDescriptors() {
-   <-sq.moreFreeDescriptors
-}

 // OfferDescriptorChain offers a descriptor chain to the device which contains a
 // number of device-readable buffers (out buffers) and device-writable buffers
 // (in buffers).

@@ -271,63 +287,9 @@ func (sq *SplitQueue) blockForMoreDescriptors() {
 // used descriptor chains again using [SplitQueue.FreeDescriptorChain] when
 // they're done with them. When this does not happen, the queue will run full
 // and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
-func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int, waitFree bool) (uint16, error) {
-   sq.ensureInitialized()
-
-   // TODO change this
-   // Each descriptor can only hold a whole memory page, so split large out
-   // buffers into multiple smaller ones.
-   outBuffers = splitBuffers(outBuffers, sq.pageSize)
-
-   // Synchronize the offering of descriptor chains. While the descriptor table
-   // and available ring are synchronized on their own as well, this does not
-   // protect us from interleaved calls which could cause reordering.
-   // By locking here, we can ensure that all descriptor chains are made
-   // available to the device in the same order as this method was called.
-   sq.offerMutex.Lock()
-   defer sq.offerMutex.Unlock()
-
-   // Create a descriptor chain for the given buffers.
-   var (
-       head uint16
-       err  error
-   )
-   for {
-       head, err = sq.descriptorTable.createDescriptorChain(outBuffers, numInBuffers)
-       if err == nil {
-           break
-       }
-
-       // I don't wanna use errors.Is, it's slow
-       //goland:noinspection GoDirectComparisonOfErrors
-       if err == ErrNotEnoughFreeDescriptors {
-           if waitFree {
-               // Wait for more free descriptors to be put back into the queue.
-               // If the number of free descriptors is still not sufficient, we'll
-               // land here again.
-               sq.blockForMoreDescriptors()
-               continue
-           } else {
-               return 0, err
-           }
-       }
-       return 0, fmt.Errorf("create descriptor chain: %w", err)
-   }
-
-   // Make the descriptor chain available to the device.
-   sq.availableRing.offer([]uint16{head})
-
-   // Notify the device to make it process the updated available ring.
-   if err := sq.kickEventFD.Kick(); err != nil {
-       return head, fmt.Errorf("notify device: %w", err)
-   }
-
-   return head, nil
-}
-
-func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (uint16, error) {
+func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int) (uint16, error) {
    sq.ensureInitialized()

    // Synchronize the offering of descriptor chains. While the descriptor table
    // and available ring are synchronized on their own as well, this does not
    // protect us from interleaved calls which could cause reordering.
@@ -350,21 +312,14 @@ func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (
        // I don't wanna use errors.Is, it's slow
        //goland:noinspection GoDirectComparisonOfErrors
        if err == ErrNotEnoughFreeDescriptors {
-           if waitFree {
-               // Wait for more free descriptors to be put back into the queue.
-               // If the number of free descriptors is still not sufficient, we'll
-               // land here again.
-               sq.blockForMoreDescriptors()
-               continue
-           } else {
-               return 0, err
-           }
+           return 0, err
+       } else {
+           return 0, fmt.Errorf("create descriptor chain: %w", err)
        }
-       return 0, fmt.Errorf("create descriptor chain: %w", err)
    }

    // Make the descriptor chain available to the device.
-   sq.availableRing.offer([]uint16{head})
+   sq.availableRing.offerSingle(head)

    // Notify the device to make it process the updated available ring.
    if err := sq.kickEventFD.Kick(); err != nil {

@@ -374,7 +329,7 @@ func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (
    return head, nil
 }

-func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
+func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
    sq.ensureInitialized()

    // TODO change this
@@ -408,15 +363,11 @@ func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]by
        // I don't wanna use errors.Is, it's slow
        //goland:noinspection GoDirectComparisonOfErrors
        if err == ErrNotEnoughFreeDescriptors {
-           if waitFree {
-               // Wait for more free descriptors to be put back into the queue.
-               // If the number of free descriptors is still not sufficient, we'll
-               // land here again.
-               sq.blockForMoreDescriptors()
-               continue
-           } else {
-               return nil, err
-           }
+           // Wait for more free descriptors to be put back into the queue.
+           // If the number of free descriptors is still not sufficient, we'll
+           // land here again.
+           <-sq.moreFreeDescriptors
+           continue
        }
        return nil, fmt.Errorf("create descriptor chain: %w", err)
    }
@@ -473,7 +424,7 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {

    // There is more free room in the descriptor table now.
    // This is a fire-and-forget signal, so do not block when nobody listens.
-   select {
+   select { //todo eliminate
    case sq.moreFreeDescriptors <- struct{}{}:
    default:
    }

@@ -481,7 +432,7 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
    return nil
 }

-func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error {
+func (sq *SplitQueue) RecycleDescriptorChains(chains []UsedElement) error {
    sq.ensureInitialized()

    //todo I don't think we need this here?

@@ -500,7 +451,7 @@ func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error {
    //}

    // Make the descriptor chain available to the device.
-   sq.availableRing.offer([]uint16{head})
+   sq.availableRing.offerElements(chains)

    // Notify the device to make it process the updated available ring.
    if err := sq.kickEventFD.Kick(); err != nil {

@@ -524,7 +475,7 @@ func (sq *SplitQueue) Close() error {

        // The stop function blocked until the goroutine ended, so the channel
        // can now safely be closed.
-       close(sq.usedChains)
+       close(sq.UsedChains)

        // Make sure that this code block is executed only once.
        sq.stop = nil

@@ -15,3 +15,7 @@ type UsedElement struct {
    // the buffer described by the descriptor chain.
    Length uint32
 }
+
+func (u *UsedElement) GetHead() uint16 {
+   return uint16(u.DescriptorIndex)
+}
@@ -87,14 +87,14 @@ func (r *UsedRing) Address() uintptr {
 // take returns all new [UsedElement]s that the device put into the ring and
 // that weren't already returned by a previous call to this method.
 // had a lock, I removed it
-func (r *UsedRing) take() []UsedElement {
+func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
    //r.mu.Lock()
    //defer r.mu.Unlock()

    ringIndex := *r.ringIndex
    if ringIndex == r.lastIndex {
        // Nothing new.
-       return nil
+       return 0, nil
    }

    // Calculate the number of new used elements that we can read from the ring.

@@ -104,6 +104,16 @@ func (r *UsedRing) take() []UsedElement {
        count += 0xffff
    }

+   stillNeedToTake := 0
+
+   if maxToTake > 0 {
+       stillNeedToTake = count - maxToTake
+       if stillNeedToTake < 0 {
+           stillNeedToTake = 0
+       }
+       count = min(count, maxToTake)
+   }
+
    // The number of new elements can never exceed the queue size.
    if count > len(r.ring) {
        panic("used ring contains more new elements than the ring is long")

@@ -115,5 +125,5 @@ func (r *UsedRing) take() []UsedElement {
        r.lastIndex++
    }

-   return elems
+   return stillNeedToTake, elems
 }