broken chkpt

JackDoan
2025-11-11 11:38:43 -06:00
parent c645a45438
commit e7f01390a3
8 changed files with 271 additions and 113 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/overlay/virtio"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp"
)
@@ -308,18 +309,31 @@ func (f *Interface) listenOut(q int) {
})
}
-func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
+func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
runtime.LockOSThread()
-packet := make([]byte, mtu)
const batch = 64
originalPackets := make([][]byte, batch) //todo batch config
for i := 0; i < batch; i++ {
originalPackets[i] = make([]byte, 0xffff)
}
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
queues := reader.GetQueues()
if len(queues) == 0 {
f.l.Fatal("Failed to get queues")
}
queue := queues[0]
_ = queue //todo: not used yet; reads still go through ReadMany
for {
-n, err := reader.Read(packet)
+n, err := reader.ReadMany(originalPackets)
+//todo!!
+pkt := originalPackets[0][virtio.NetHdrSize : n+virtio.NetHdrSize]
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
@@ -330,7 +344,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
os.Exit(2)
}
-f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
+f.consumeInsidePacket(pkt, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l))
}
}

View File

@@ -2,20 +2,21 @@ package overlay
import (
"fmt"
"io"
"net"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/util"
)
const DefaultMTU = 1300
type TunDev interface {
io.ReadWriteCloser
ReadMany([][]byte) (int, error)
WriteMany([][]byte) (int, error)
GetQueues() []*virtqueue.SplitQueue
}
// TODO: We may be able to remove routines
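For orientation, a consumer of this interface drives the batch path roughly as in the sketch below. readLoop and handle are hypothetical names; the 64-buffer batch and 0xffff buffer size mirror the listenIn change above, and virtio.NetHdrSize accounts for the virtio-net header that ReadMany leaves in front of each packet.

func readLoop(dev TunDev, handle func([]byte)) error {
	bufs := make([][]byte, 64) // illustrative batch size, as in listenIn
	for i := range bufs {
		bufs[i] = make([]byte, 0xffff)
	}
	for {
		// ReadMany fills bufs[0] with one packet (virtio-net header first)
		// and returns the payload length in bytes.
		n, err := dev.ReadMany(bufs)
		if err != nil {
			return err
		}
		handle(bufs[0][virtio.NetHdrSize : n+virtio.NetHdrSize])
	}
}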

View File

@@ -9,6 +9,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/routing"
)
@@ -40,6 +41,10 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
return tun
}
func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
return nil
}
func (*disabledTun) Activate() error {
return nil
}
@@ -117,6 +122,10 @@ func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
return out, nil
}
func (t *disabledTun) ReadMany(b [][]byte) (int, error) {
return t.Read(b[0])
}
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
return t, nil
}

View File

@@ -5,7 +5,6 @@ package overlay
import (
"fmt"
"io"
"net"
"net/netip"
"os"
@@ -20,6 +19,7 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/vhostnet"
"github.com/slackhq/nebula/overlay/virtio"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
@@ -27,7 +27,7 @@ import (
)
type tun struct {
-io.ReadWriteCloser
+file *os.File
fd int
vdev *vhostnet.Device
Device string
@@ -51,6 +51,10 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) GetQueues() []*virtqueue.SplitQueue {
return []*virtqueue.SplitQueue{t.vdev.ReceiveQueue, t.vdev.TransmitQueue}
}
type ifReq struct {
Name [16]byte
Flags uint16
@@ -129,8 +133,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
}
flags := 0
-//flags := unix.TUN_F_CSUM
-//|unix.TUN_F_USO4|unix.TUN_F_USO6
+//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
if err != nil {
return nil, fmt.Errorf("set offloads: %w", err)
@@ -168,7 +171,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
-ReadWriteCloser: file,
+file: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
@@ -699,8 +702,8 @@ func (t *tun) Close() error {
_ = t.vdev.Close()
}
-if t.ReadWriteCloser != nil {
-_ = t.ReadWriteCloser.Close()
+if t.file != nil {
+_ = t.file.Close()
}
if t.ioctlFd > 0 {
@@ -710,17 +713,17 @@ func (t *tun) Close() error {
return nil
}
-func (t *tun) Read(p []byte) (int, error) {
-hdr, out, err := t.vdev.ReceivePacket() //we are TXing
+func (t *tun) ReadMany(p [][]byte) (int, error) {
+//todo call consumeUsedRing here instead of its own thread
+n, hdr, err := t.vdev.ReceivePacket(p[0]) //we are TXing
if err != nil {
return 0, err
}
if hdr.NumBuffers > 1 {
t.l.WithField("num_buffers", hdr.NumBuffers).Info("wow, lots to TX from tun")
}
-p = p[:len(out)]
-copy(p, out)
-return len(out), nil
+return n, nil
}
func (t *tun) Write(b []byte) (int, error) {

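For reference, the NumBuffers field consulted above comes from the standard virtio-net header that prefixes every packet on these queues. A sketch of its layout per the virtio 1.x spec follows; the field names are illustrative, and the repo's own type lives in overlay/virtio.

// Layout of the virtio-net header; all multi-byte fields are
// little-endian. It is 12 bytes (NetHdrSize) when
// VIRTIO_NET_F_MRG_RXBUF or VIRTIO_F_VERSION_1 is negotiated.
type NetHdr struct {
	Flags      uint8  // e.g. VIRTIO_NET_HDR_F_NEEDS_CSUM
	GSOType    uint8  // VIRTIO_NET_HDR_GSO_NONE, _TCPV4, _TCPV6, ...
	HdrLen     uint16 // length of the packet headers, used for GSO
	GSOSize    uint16 // GSO segment size
	CsumStart  uint16 // where checksumming starts
	CsumOffset uint16 // checksum field offset from CsumStart
	NumBuffers uint16 // descriptor chains merged into this packet (RX only)
}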
View File

@@ -6,6 +6,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/routing"
)
@@ -66,6 +67,10 @@ func (d *UserDevice) Close() error {
return nil
}
func (d *UserDevice) ReadMany(b [][]byte) (int, error) {
return d.Read(b[0])
}
func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
out := 0
for i := range b {
@@ -77,3 +82,7 @@ func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
}
return out, nil
}
func (*UserDevice) GetQueues() []*virtqueue.SplitQueue {
return nil
}

View File

@@ -28,8 +28,8 @@ type Device struct {
initialized bool
controlFD int
-receiveQueue *virtqueue.SplitQueue
-transmitQueue *virtqueue.SplitQueue
+ReceiveQueue *virtqueue.SplitQueue
+TransmitQueue *virtqueue.SplitQueue
// transmitted contains channels for each possible descriptor chain head
// index. This is used for packet transmit notifications.
@@ -96,17 +96,17 @@ func NewDevice(options ...Option) (*Device, error) {
}
// Initialize and register the queues needed for the networking device.
-if dev.receiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
+if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
-if dev.transmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
+if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
// Set up memory mappings for all buffers used by the queues. This has to
// happen before a backend for the queues can be registered.
memoryLayout := vhost.NewMemoryLayoutForQueues(
-[]*virtqueue.SplitQueue{dev.receiveQueue, dev.transmitQueue},
+[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
)
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
return nil, fmt.Errorf("setup memory layout: %w", err)
@@ -127,7 +127,7 @@ func NewDevice(options ...Option) (*Device, error) {
}
// Initialize channels for transmit notifications.
-dev.transmitted = make([]chan virtqueue.UsedElement, dev.transmitQueue.Size())
+dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
for i := range len(dev.transmitted) {
// It is important to use a single-element buffered channel here.
// When the channel was unbuffered and the monitorTransmitQueue
@@ -159,7 +159,7 @@ func NewDevice(options ...Option) (*Device, error) {
// in the transmit queue and produces a transmit notification via the
// corresponding channel.
func (dev *Device) monitorTransmitQueue() {
-usedChan := dev.transmitQueue.UsedDescriptorChains()
+usedChan := dev.TransmitQueue.UsedDescriptorChains()
for {
used, ok := <-usedChan
if !ok {
@@ -180,7 +180,7 @@ func (dev *Device) monitorTransmitQueue() {
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
-_, err := dev.receiveQueue.OfferDescriptorChain(nil, 1, false)
+_, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false)
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
@@ -212,17 +212,17 @@ func (dev *Device) Close() error {
var errs []error
-if dev.receiveQueue != nil {
-if err := dev.receiveQueue.Close(); err == nil {
-dev.receiveQueue = nil
+if dev.ReceiveQueue != nil {
+if err := dev.ReceiveQueue.Close(); err == nil {
+dev.ReceiveQueue = nil
} else {
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
}
}
-if dev.transmitQueue != nil {
-if err := dev.transmitQueue.Close(); err == nil {
-dev.transmitQueue = nil
+if dev.TransmitQueue != nil {
+if err := dev.TransmitQueue.Close(); err == nil {
+dev.TransmitQueue = nil
} else {
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
}
@@ -296,7 +296,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
outBuffers := [][]byte{vnethdrBuf, packet}
//outBuffers := [][]byte{packet}
-chainIndex, err := dev.transmitQueue.OfferDescriptorChain(outBuffers, 0, true)
+chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
@@ -304,7 +304,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
// Wait for the packet to have been transmitted.
<-dev.transmitted[chainIndex]
-if err = dev.transmitQueue.FreeDescriptorChain(chainIndex); err != nil {
+if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
@@ -320,7 +320,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
-chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
+chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
@@ -330,7 +330,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
for i := range chainIndexes {
<-dev.transmitted[chainIndexes[i]]
-if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
+if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
}
@@ -346,7 +346,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
// When this method returns an error, the receive queue will likely be in a
// broken state which this implementation cannot recover from. The caller should
// close the device and not attempt any additional receives.
-func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
+func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
var (
chainHeads []uint16
@@ -358,41 +358,30 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
packetLength = -virtio.NetHdrSize
)
lenRead := 0
// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
// made of multiple descriptor chains which are to be merged.
for remainingChains := 1; remainingChains > 0; remainingChains-- {
// Get the next descriptor chain.
-usedElement, ok := <-dev.receiveQueue.UsedDescriptorChains()
+usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
if !ok {
-return virtio.NetHdr{}, nil, ErrDeviceClosed
+return 0, virtio.NetHdr{}, ErrDeviceClosed
}
// Track this chain to be freed later.
head := uint16(usedElement.DescriptorIndex)
chainHeads = append(chainHeads, head)
-outBuffers, inBuffers, err := dev.receiveQueue.GetDescriptorChain(head)
+n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:])
if err != nil {
// When this fails we may miss to free some descriptor chains. We
// could try to mitigate this by deferring the freeing somehow, but
// it's not worth the hassle. When this method fails, the queue will
// be in a broken state anyway.
return virtio.NetHdr{}, nil, fmt.Errorf("get descriptor chain: %w", err)
return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err)
}
-if len(outBuffers) > 0 {
-// How did this happen!?
-panic("receive queue contains device-readable buffers")
-}
-if len(inBuffers) == 0 {
-// Empty descriptor chains should not be possible.
-panic("descriptor chain contains no buffers")
-}
// The device tells us how many bytes of the descriptor chain it has
// actually written to. The specification forces the device to fully
// fill up all but the last descriptor chain when multiple descriptor
// chains are being merged, but being more compatible here doesn't hurt.
-inBuffers = truncateBuffers(inBuffers, int(usedElement.Length))
+lenRead += n
packetLength += int(usedElement.Length)
// Is this the first descriptor chain we process?
@@ -403,49 +392,51 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
-if err = vnethdr.Decode(inBuffers[0]); err != nil {
+if err = vnethdr.Decode(out[0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return virtio.NetHdr{}, nil, fmt.Errorf("decode vnethdr: %w", err)
return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err)
}
-inBuffers[0] = inBuffers[0][virtio.NetHdrSize:]
+out = out[virtio.NetHdrSize:]
+lenRead = n - virtio.NetHdrSize //keep the first chain's payload in place; only the header has been consumed
// The virtio-net header tells us how many descriptor chains this
// packet is long.
remainingChains = int(vnethdr.NumBuffers)
}
-buffers = append(buffers, inBuffers...)
+//buffers = append(buffers, inBuffers...)
}
-// Copy all the buffers together to produce the complete packet slice.
-packet := make([]byte, packetLength)
-copied := 0
-for _, buffer := range buffers {
-copied += copy(packet[copied:], buffer)
-}
-if copied != packetLength {
-panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
-}
+//out = out[:packetLength]
+//copied := 0
+//for _, buffer := range buffers {
+// copied += copy(out[copied:], buffer)
+//}
+//if copied != packetLength {
+// panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
+//}
// Now that we have copied all buffers, we can free the used descriptor
// chains again.
// TODO: Recycling the descriptor chains would be more efficient than
// freeing them just to offer them again right after.
for _, head := range chainHeads {
-if err := dev.receiveQueue.FreeDescriptorChain(head); err != nil {
-return virtio.NetHdr{}, nil, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
+if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
+return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
}
}
//if we don't churn chains, maybe we don't need this?
// It's advised to always keep the receive queue fully populated with
// available buffers which the device can write new packets into.
-if err := dev.refillReceiveQueue(); err != nil {
-return virtio.NetHdr{}, nil, fmt.Errorf("refill receive queue: %w", err)
-}
+//if err := dev.refillReceiveQueue(); err != nil {
+//	return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
+//}
-return vnethdr, packet, nil
+return packetLength, vnethdr, nil
}
// TODO: Make above methods cancelable by taking a context.Context argument?

View File

@@ -349,6 +349,74 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
return
}
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (int, error) {
if int(head) >= len(dt.descriptors) {
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
dt.mu.Lock()
defer dt.mu.Unlock()
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
length := 0
// First pass: walk the chain to measure its total length.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
if desc.flags&descriptorFlagWritable == 0 {
return 0, fmt.Errorf("receive queue contains device-readable buffer")
}
length += int(desc.length)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
// Bound the output to the measured length. This assumes the caller's
// buffer is at least that large; the reslice panics otherwise.
out = out[:length]
// Second pass: do the copying, starting over from the head of the chain.
next = head
copied := 0
for range len(dt.descriptors) {
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
copied += copy(out[copied:], bs)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// we did this already, no need to detect loops.
next = desc.next
}
if copied != length {
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
}
return length, nil
}
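As noted in the comments above, the out = out[:length] reslice assumes the caller's buffer covers the whole chain. A minimal caller sketch, with the 0xffff size borrowed from the batch buffers in listenIn:

buf := make([]byte, 0xffff) // must cover the full chain or the reslice panics
n, err := dt.getDescriptorChainContents(head, buf)
if err != nil {
	return err
}
pkt := buf[:n] // exactly n bytes were copied out of the chain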
// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls

View File

@@ -49,6 +49,8 @@ type SplitQueue struct {
offerMutex sync.Mutex
pageSize int
itemSize int
epoll eventfd.Epoll
}
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
@@ -132,6 +134,15 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
sq.usedChains = make(chan UsedElement, queueSize)
sq.moreFreeDescriptors = make(chan struct{})
sq.epoll, err = eventfd.NewEpoll()
if err != nil {
return nil, err
}
err = sq.epoll.AddEvent(sq.callEventFD.FD())
if err != nil {
return nil, err
}
// Consume used buffer notifications in the background.
sq.stop = sq.startConsumeUsedRing()
@@ -194,25 +205,9 @@ func (sq *SplitQueue) UsedDescriptorChains() chan UsedElement {
}
// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
-// A function is returned that can be used to gracefully cancel it.
+// A function is returned that can be used to gracefully cancel it. todo rename
func (sq *SplitQueue) startConsumeUsedRing() func() error {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error)
-ep, err := eventfd.NewEpoll()
-if err != nil {
-panic(err)
-}
-err = ep.AddEvent(sq.callEventFD.FD())
-if err != nil {
-panic(err)
-}
go func() {
done <- sq.consumeUsedRing(ctx, &ep)
}()
return func() error {
cancel()
// The goroutine blocks until it receives a signal on the event file
// descriptor, so it will never notice the context being canceled.
@@ -221,43 +216,28 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
if err := sq.callEventFD.Kick(); err != nil {
return fmt.Errorf("wake up goroutine: %w", err)
}
// Wait for the goroutine to end. This prevents the event file
// descriptor to be closed while it's still being used.
// If the goroutine failed, this is the last chance to propagate the
// error so it at least doesn't go unnoticed, even though the error may
// be older already.
if err := <-done; err != nil {
return fmt.Errorf("goroutine: consume used ring: %w", err)
}
return nil
}
}
-// consumeUsedRing runs in a goroutine, waits for the device to signal that it
-// has used descriptor chains and puts all new [UsedElement]s into the channel
-// for them.
-func (sq *SplitQueue) consumeUsedRing(ctx context.Context, epoll *eventfd.Epoll) error {
+// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s
+func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
// Wait for a signal from the device.
-if n, err = epoll.Block(); err != nil {
-return fmt.Errorf("wait: %w", err)
+if n, err = sq.epoll.Block(); err != nil {
+return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
-_ = epoll.Clear() //???
-// Process all new used elements.
-for _, usedElement := range sq.usedRing.take() {
-sq.usedChains <- usedElement
-}
+out := sq.usedRing.take()
+_ = sq.epoll.Clear() //???
+return out, nil
}
}
-return nil
+return nil, ctx.Err()
}
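With the epoll handle now owned by the queue, the old channel-pumping goroutine can collapse into a caller-side loop along these lines (a sketch; drainUsed and handle are hypothetical, and cancellation comes from the caller's context):

func drainUsed(ctx context.Context, sq *SplitQueue, handle func(UsedElement)) error {
	for {
		// Blocks on the queue's epoll until the device signals used chains.
		elems, err := sq.BlockAndGetHeads(ctx)
		if err != nil {
			return err // ctx.Err() once the context is canceled
		}
		for _, e := range elems {
			handle(e)
		}
	}
}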
// blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up.
@@ -345,6 +325,55 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
return head, nil
}
func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (uint16, error) {
sq.ensureInitialized()
// Synchronize the offering of descriptor chains. While the descriptor table
// and available ring are synchronized on their own as well, this does not
// protect us from interleaved calls which could cause reordering.
// By locking here, we can ensure that all descriptor chains are made
// available to the device in the same order as this method was called.
sq.offerMutex.Lock()
defer sq.offerMutex.Unlock()
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for {
head, err = sq.descriptorTable.createDescriptorChain(nil, numInBuffers)
if err == nil {
break
}
// I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
if waitFree {
// Wait for more free descriptors to be put back into the queue.
// If the number of free descriptors is still not sufficient, we'll
// land here again.
sq.blockForMoreDescriptors()
continue
} else {
return 0, err
}
}
return 0, fmt.Errorf("create descriptor chain: %w", err)
}
// Make the descriptor chain available to the device.
sq.availableRing.offer([]uint16{head})
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return head, fmt.Errorf("notify device: %w", err)
}
return head, nil
}
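This in-only variant covers what refillReceiveQueue in vhostnet does today via OfferDescriptorChain(nil, 1, false); a refill loop against it might read (sketch):

// Keep offering single device-writable buffers until the queue is full.
for {
	if _, err := sq.OfferInDescriptorChains(1, false); err != nil {
		if errors.Is(err, ErrNotEnoughFreeDescriptors) {
			break // queue fully populated, done
		}
		return fmt.Errorf("offer receive buffer: %w", err)
	}
}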
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
sq.ensureInitialized()
@@ -420,6 +449,11 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
return sq.descriptorTable.getDescriptorChain(head)
}
func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte) (int, error) {
sq.ensureInitialized()
return sq.descriptorTable.getDescriptorChainContents(head, out)
}
// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
@@ -447,6 +481,35 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
return nil
}
func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error {
sq.ensureInitialized()
//todo I don't think we need this here?
// Synchronize the offering of descriptor chains. While the descriptor table
// and available ring are synchronized on their own as well, this does not
// protect us from interleaved calls which could cause reordering.
// By locking here, we can ensure that all descriptor chains are made
// available to the device in the same order as this method was called.
//sq.offerMutex.Lock()
//defer sq.offerMutex.Unlock()
//todo not doing this may break eventually?
//not called under lock
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
// return fmt.Errorf("free: %w", err)
//}
// Make the descriptor chain available to the device.
sq.availableRing.offer([]uint16{head})
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return fmt.Errorf("notify device: %w", err)
}
return nil
}
// Close releases all resources used for this queue.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.