From e7f01390a3a52a0118638e56e1fca27cb623e037 Mon Sep 17 00:00:00 2001
From: JackDoan
Date: Tue, 11 Nov 2025 11:38:43 -0600
Subject: [PATCH] broken chkpt

---
 interface.go                          |  22 +++-
 overlay/tun.go                        |   5 +-
 overlay/tun_disabled.go               |   9 ++
 overlay/tun_linux.go                  |  27 ++---
 overlay/user.go                       |   9 ++
 overlay/vhostnet/device.go            | 101 +++++++++---------
 overlay/virtqueue/descriptor_table.go |  68 ++++++++++++
 overlay/virtqueue/split_virtqueue.go  | 143 +++++++++++++++++++-------
 8 files changed, 271 insertions(+), 113 deletions(-)

diff --git a/interface.go b/interface.go
index 4fcfd67..2b13931 100644
--- a/interface.go
+++ b/interface.go
@@ -18,6 +18,7 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/overlay"
+	"github.com/slackhq/nebula/overlay/virtio"
 	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/udp"
 )
@@ -308,18 +309,31 @@ func (f *Interface) listenOut(q int) {
 	})
 }
 
-func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
+func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
 	runtime.LockOSThread()
 
-	packet := make([]byte, mtu)
+	const batch = 64
+	originalPackets := make([][]byte, batch) //todo: make batch size configurable
+	for i := 0; i < batch; i++ {
+		originalPackets[i] = make([]byte, 0xffff)
+	}
 	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
+	queues := reader.GetQueues()
+	if len(queues) == 0 {
+		f.l.Fatal("Failed to get queues")
+	}
+	_ = queues[0] //todo: drive this queue here once consumeUsedRing moves inline
+
 	for {
-		n, err := reader.Read(packet)
+
+		n, err := reader.ReadMany(originalPackets)
+		//todo: only the first buffer is filled today; handle whole batches
+		pkt := originalPackets[0][virtio.NetHdrSize : n+virtio.NetHdrSize]
 		if err != nil {
 			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
 				return
@@ -330,7 +344,7 @@
 			os.Exit(2)
 		}
 
-		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		f.consumeInsidePacket(pkt, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l))
 	}
 }

diff --git a/overlay/tun.go b/overlay/tun.go
index 7c84d97..b58d6a8 100644
--- a/overlay/tun.go
+++ b/overlay/tun.go
@@ -2,20 +2,21 @@ package overlay
 
 import (
 	"fmt"
-	"io"
 	"net"
 	"net/netip"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/util"
 )
 
 const DefaultMTU = 1300
 
 type TunDev interface {
-	io.ReadWriteCloser
+	ReadMany([][]byte) (int, error)
 	WriteMany([][]byte) (int, error)
+	GetQueues() []*virtqueue.SplitQueue
 }
 
 // TODO: We may be able to remove routines

diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go
index 875fa3c..f2e9c6b 100644
--- a/overlay/tun_disabled.go
+++ b/overlay/tun_disabled.go
@@ -9,6 +9,7 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/routing"
 )
@@ -40,6 +41,10 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
 	return tun
 }
 
+func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
+	return nil
+}
+
 func (*disabledTun) Activate() error {
 	return nil
 }
@@ -117,6 +122,10 @@ func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
 	return out, nil
 }
 
+func (t *disabledTun) ReadMany(b [][]byte) (int, error) {
+	return t.Read(b[0])
+}
+
 func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
 	return t, nil
 }

diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index 722091c..f663ee0 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -5,7 +5,6 @@ package overlay
 
 import (
 	"fmt"
-	"io"
 	"net"
 	"net/netip"
 	"os"
@@ -20,6 +19,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/overlay/vhostnet"
 	"github.com/slackhq/nebula/overlay/virtio"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/vishvananda/netlink"
@@ -27,7 +27,7 @@ import (
 )
 
 type tun struct {
-	io.ReadWriteCloser
+	file *os.File
 	fd     int
 	vdev   *vhostnet.Device
 	Device string
@@ -51,6 +51,10 @@ func (t *tun) Networks() []netip.Prefix {
 	return t.vpnNetworks
 }
 
+func (t *tun) GetQueues() []*virtqueue.SplitQueue {
+	return []*virtqueue.SplitQueue{t.vdev.ReceiveQueue, t.vdev.TransmitQueue}
+}
+
 type ifReq struct {
 	Name  [16]byte
 	Flags uint16
@@ -129,8 +133,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	}
 
 	flags := 0
-	//flags := unix.TUN_F_CSUM
-	//|unix.TUN_F_USO4|unix.TUN_F_USO6
+	//flags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
 	err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
 	if err != nil {
 		return nil, fmt.Errorf("set offloads: %w", err)
 	}
@@ -168,7 +171,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 
 func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
-		ReadWriteCloser: file,
+		file:        file,
 		fd:          int(file.Fd()),
 		vpnNetworks: vpnNetworks,
 		TXQueueLen:  c.GetInt("tun.tx_queue", 500),
@@ -699,8 +702,8 @@ func (t *tun) Close() error {
 		_ = t.vdev.Close()
 	}
 
-	if t.ReadWriteCloser != nil {
-		_ = t.ReadWriteCloser.Close()
+	if t.file != nil {
+		_ = t.file.Close()
 	}
 
 	if t.ioctlFd > 0 {
@@ -710,17 +713,17 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func (t *tun) Read(p []byte) (int, error) {
-	hdr, out, err := t.vdev.ReceivePacket() //we are TXing
+func (t *tun) ReadMany(p [][]byte) (int, error) {
+	//todo: call consumeUsedRing here instead of in its own thread
+
+	n, hdr, err := t.vdev.ReceivePacket(p[0]) //packets read here get TXed out the udp side
 	if err != nil {
 		return 0, err
 	}
 	if hdr.NumBuffers > 1 {
 		t.l.WithField("num_buffers", hdr.NumBuffers).Info("wow, lots to TX from tun")
 	}
-	p = p[:len(out)]
-	copy(p, out)
-	return len(out), nil
+	return n, nil
 }
 
 func (t *tun) Write(b []byte) (int, error) {

diff --git a/overlay/user.go b/overlay/user.go
index a1a937c..34b359f 100644
--- a/overlay/user.go
+++ b/overlay/user.go
@@ -6,6 +6,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/routing"
 )
@@ -66,6 +67,10 @@ func (d *UserDevice) Close() error {
 	return nil
 }
 
+func (d *UserDevice) ReadMany(b [][]byte) (int, error) {
+	return d.Read(b[0])
+}
+
 func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
 	out := 0
 	for i := range b {
 		n, err := d.Write(b[i])
 		out += n
 		if err != nil {
 			return out, err
 		}
 	}
 	return out, nil
 }
+
+func (*UserDevice) GetQueues() []*virtqueue.SplitQueue {
+	return nil
+}

diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go
index 53f0308..f4f132c 100644
--- a/overlay/vhostnet/device.go
+++ b/overlay/vhostnet/device.go
@@ -28,8 +28,8 @@ type Device struct {
 	initialized bool
 	controlFD   int
 
-	receiveQueue  *virtqueue.SplitQueue
-	transmitQueue *virtqueue.SplitQueue
+	ReceiveQueue  *virtqueue.SplitQueue
+	TransmitQueue *virtqueue.SplitQueue
 
 	// transmitted contains channels for each possible descriptor
 	// chain head index. This is used for packet transmit notifications.
@@ -96,17 +96,17 @@ func NewDevice(options ...Option) (*Device, error) {
 	}
 
 	// Initialize and register the queues needed for the networking device.
-	if dev.receiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
+	if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
 		return nil, fmt.Errorf("create receive queue: %w", err)
 	}
-	if dev.transmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
+	if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
 		return nil, fmt.Errorf("create transmit queue: %w", err)
 	}
 
 	// Set up memory mappings for all buffers used by the queues. This has to
 	// happen before a backend for the queues can be registered.
 	memoryLayout := vhost.NewMemoryLayoutForQueues(
-		[]*virtqueue.SplitQueue{dev.receiveQueue, dev.transmitQueue},
+		[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
 	)
 	if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
 		return nil, fmt.Errorf("setup memory layout: %w", err)
 	}
@@ -127,7 +127,7 @@ func NewDevice(options ...Option) (*Device, error) {
 	}
 
 	// Initialize channels for transmit notifications.
-	dev.transmitted = make([]chan virtqueue.UsedElement, dev.transmitQueue.Size())
+	dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
 	for i := range len(dev.transmitted) {
 		// It is important to use a single-element buffered channel here.
 		// When the channel was unbuffered and the monitorTransmitQueue
@@ -159,7 +159,7 @@
 // in the transmit queue and produces a transmit notification via the
 // corresponding channel.
 func (dev *Device) monitorTransmitQueue() {
-	usedChan := dev.transmitQueue.UsedDescriptorChains()
+	usedChan := dev.TransmitQueue.UsedDescriptorChains()
 	for {
 		used, ok := <-usedChan
 		if !ok {
@@ -180,7 +180,7 @@
 // packets.
 func (dev *Device) refillReceiveQueue() error {
 	for {
-		_, err := dev.receiveQueue.OfferDescriptorChain(nil, 1, false)
+		_, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false)
 		if err != nil {
 			if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
 				// Queue is full, job is done.
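+				// (Once ReceivePacket recycles chains via
+				// FreeAndOfferDescriptorChains, this refill loop should find
+				// the queue already full in the common case.)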
@@ -212,17 +212,17 @@ func (dev *Device) Close() error {
 
 	var errs []error
 
-	if dev.receiveQueue != nil {
-		if err := dev.receiveQueue.Close(); err == nil {
-			dev.receiveQueue = nil
+	if dev.ReceiveQueue != nil {
+		if err := dev.ReceiveQueue.Close(); err == nil {
+			dev.ReceiveQueue = nil
 		} else {
 			errs = append(errs, fmt.Errorf("close receive queue: %w", err))
 		}
 	}
-	if dev.transmitQueue != nil {
-		if err := dev.transmitQueue.Close(); err == nil {
-			dev.transmitQueue = nil
+	if dev.TransmitQueue != nil {
+		if err := dev.TransmitQueue.Close(); err == nil {
+			dev.TransmitQueue = nil
 		} else {
 			errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
 		}
 	}
@@ -296,7 +296,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
 
 	outBuffers := [][]byte{vnethdrBuf, packet}
 	//outBuffers := [][]byte{packet}
-	chainIndex, err := dev.transmitQueue.OfferDescriptorChain(outBuffers, 0, true)
+	chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
 	if err != nil {
 		return fmt.Errorf("offer descriptor chain: %w", err)
 	}
@@ -304,7 +304,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
 
 	// Wait for the packet to have been transmitted.
 	<-dev.transmitted[chainIndex]
 
-	if err = dev.transmitQueue.FreeDescriptorChain(chainIndex); err != nil {
+	if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
 		return fmt.Errorf("free descriptor chain: %w", err)
 	}
 
@@ -320,7 +320,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 	vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
 	vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
 
-	chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
+	chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
 	if err != nil {
 		return fmt.Errorf("offer descriptor chain: %w", err)
 	}
@@ -330,7 +330,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 
 	for i := range chainIndexes {
 		<-dev.transmitted[chainIndexes[i]]
 
-		if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
+		if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
 			return fmt.Errorf("free descriptor chain: %w", err)
 		}
 	}
@@ -346,7 +346,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 // When this method returns an error, the receive queue will likely be in a
 // broken state which this implementation cannot recover from. The caller should
 // close the device and not attempt any additional receives.
-func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
+func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
 	var (
 		chainHeads []uint16
 
@@ -358,41 +358,30 @@
 		packetLength = -virtio.NetHdrSize
 	)
 
+	lenRead := 0
+
 	// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
 	// made of multiple descriptor chains which are to be merged.
 	for remainingChains := 1; remainingChains > 0; remainingChains-- {
 		// Get the next descriptor chain.
-		usedElement, ok := <-dev.receiveQueue.UsedDescriptorChains()
+		usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
 		if !ok {
-			return virtio.NetHdr{}, nil, ErrDeviceClosed
+			return 0, virtio.NetHdr{}, ErrDeviceClosed
 		}
 
 		// Track this chain to be freed later.
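+		// (They are recycled rather than torn down: once their contents have
+		// been copied out below, FreeAndOfferDescriptorChains hands them
+		// straight back to the device.)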
 		head := uint16(usedElement.DescriptorIndex)
 		chainHeads = append(chainHeads, head)
 
-		outBuffers, inBuffers, err := dev.receiveQueue.GetDescriptorChain(head)
+		n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:])
 		if err != nil {
 			// When this fails we may miss to free some descriptor chains. We
 			// could try to mitigate this by deferring the freeing somehow, but
 			// it's not worth the hassle. When this method fails, the queue will
 			// be in a broken state anyway.
-			return virtio.NetHdr{}, nil, fmt.Errorf("get descriptor chain: %w", err)
+			return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err)
 		}
 
-		if len(outBuffers) > 0 {
-			// How did this happen!?
-			panic("receive queue contains device-readable buffers")
-		}
-		if len(inBuffers) == 0 {
-			// Empty descriptor chains should not be possible.
-			panic("descriptor chain contains no buffers")
-		}
-
-		// The device tells us how many bytes of the descriptor chain it has
-		// actually written to. The specification forces the device to fully
-		// fill up all but the last descriptor chain when multiple descriptor
-		// chains are being merged, but being more compatible here doesn't hurt.
-		inBuffers = truncateBuffers(inBuffers, int(usedElement.Length))
+		lenRead += n
 		packetLength += int(usedElement.Length)
 
 		// Is this the first descriptor chain we process?
@@ -403,49 +392,51 @@
 			// descriptor chain, but it is reasonable to assume that this is
 			// always the case.
 			// The decode method already does the buffer length check.
-			if err = vnethdr.Decode(inBuffers[0]); err != nil {
+			if err = vnethdr.Decode(out); err != nil {
 				// The device misbehaved. There is no way we can gracefully
 				// recover from this, because we don't know how many of the
 				// following descriptor chains belong to this packet.
-				return virtio.NetHdr{}, nil, fmt.Errorf("decode vnethdr: %w", err)
+				return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err)
 			}
-			inBuffers[0] = inBuffers[0][virtio.NetHdrSize:]
+			lenRead -= virtio.NetHdrSize // keep in sync with the header strip below
+			out = out[virtio.NetHdrSize:]
 
 			// The virtio-net header tells us how many descriptor chains this
 			// packet is long.
 			remainingChains = int(vnethdr.NumBuffers)
 		}
-		buffers = append(buffers, inBuffers...)
+		//buffers = append(buffers, inBuffers...)
 	}
 
 	// Copy all the buffers together to produce the complete packet slice.
-	packet := make([]byte, packetLength)
-	copied := 0
-	for _, buffer := range buffers {
-		copied += copy(packet[copied:], buffer)
-	}
-	if copied != packetLength {
-		panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
-	}
+	//out = out[:packetLength]
+	//copied := 0
+	//for _, buffer := range buffers {
+	//	copied += copy(out[copied:], buffer)
+	//}
+	//if copied != packetLength {
+	//	panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
+	//}
 
 	// Now that we have copied all buffers, we can free the used descriptor
 	// chains again.
 	// TODO: Recycling the descriptor chains would be more efficient than
 	// freeing them just to offer them again right after.
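+	// (FreeAndOfferDescriptorChains below now does exactly this recycling: the
+	// chain is re-offered to the device without being torn down first.)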
 	for _, head := range chainHeads {
-		if err := dev.receiveQueue.FreeDescriptorChain(head); err != nil {
-			return virtio.NetHdr{}, nil, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
+		if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
+			return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
 		}
 	}
 
+	// Since the chains are recycled back to the device above, the refill
+	// below should no longer be needed.
 	// It's advised to always keep the receive queue fully populated with
 	// available buffers which the device can write new packets into.
-	if err := dev.refillReceiveQueue(); err != nil {
-		return virtio.NetHdr{}, nil, fmt.Errorf("refill receive queue: %w", err)
-	}
+	//if err := dev.refillReceiveQueue(); err != nil {
+	//	return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
+	//}
 
-	return vnethdr, packet, nil
+	return packetLength, vnethdr, nil
 }
 
 // TODO: Make above methods cancelable by taking a context.Context argument?

diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go
index 0b014e2..515facf 100644
--- a/overlay/virtqueue/descriptor_table.go
+++ b/overlay/virtqueue/descriptor_table.go
@@ -349,6 +349,74 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
 	return
 }
 
+func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (int, error) {
+	if int(head) >= len(dt.descriptors) {
+		return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	dt.mu.Lock()
+	defer dt.mu.Unlock()
+
+	// Iterate over the chain. The iteration is limited to the queue size to
+	// avoid ending up in an endless loop when things go very wrong.
+
+	// First pass: walk the chain to find its total length.
+	length := 0
+	next := head
+	for range len(dt.descriptors) {
+		if next == dt.freeHeadIndex {
+			return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
+		}
+
+		desc := &dt.descriptors[next]
+
+		if desc.flags&descriptorFlagWritable == 0 {
+			return 0, fmt.Errorf("receive queue contains device-readable buffer")
+		}
+		length += int(desc.length)
+
+		// Is this the tail of the chain?
+		if desc.flags&descriptorFlagHasNext == 0 {
+			break
+		}
+
+		// Detect loops.
+		if desc.next == head {
+			return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
+		}
+
+		next = desc.next
+	}
+
+	// Bound the output slice to the chain length.
+	out = out[:length]
+
+	// Second pass: copy the chain contents, starting over from the head.
+	next = head
+	copied := 0
+	for range len(dt.descriptors) {
+		desc := &dt.descriptors[next]
+
+		// The descriptor address points to memory not managed by Go, so this
+		// conversion is safe. See https://github.com/golang/go/issues/58625
+		//goland:noinspection GoVetUnsafePointer
+		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+		copied += copy(out[copied:], bs)
+
+		// Is this the tail of the chain?
+		if desc.flags&descriptorFlagHasNext == 0 {
+			break
+		}
+
+		// Loops were already detected in the first pass.
+		next = desc.next
+	}
+	if copied != length {
+		panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
+	}
+
+	return length, nil
+}
+
 // freeDescriptorChain can be used to free a descriptor chain when it is no
 // longer in use.
 // The descriptor chain that starts with the given index will be
 // put back into the free chain, so the descriptors can be used for later calls

diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go
index 9a0ba76..8f17a50 100644
--- a/overlay/virtqueue/split_virtqueue.go
+++ b/overlay/virtqueue/split_virtqueue.go
@@ -49,6 +49,8 @@ type SplitQueue struct {
 	offerMutex sync.Mutex
 	pageSize   int
 	itemSize   int
+
+	epoll eventfd.Epoll
 }
 
 // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
@@ -132,6 +134,15 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
 	sq.usedChains = make(chan UsedElement, queueSize)
 	sq.moreFreeDescriptors = make(chan struct{})
 
+	sq.epoll, err = eventfd.NewEpoll()
+	if err != nil {
+		return nil, err
+	}
+	err = sq.epoll.AddEvent(sq.callEventFD.FD())
+	if err != nil {
+		return nil, err
+	}
+
 	// Consume used buffer notifications in the background.
 	sq.stop = sq.startConsumeUsedRing()
 
@@ -194,25 +205,9 @@ func (sq *SplitQueue) UsedDescriptorChains() chan UsedElement {
 }
 
 // startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
-// A function is returned that can be used to gracefully cancel it.
+// A function is returned that wakes any blocked reader so it can shut down. todo rename
 func (sq *SplitQueue) startConsumeUsedRing() func() error {
-	ctx, cancel := context.WithCancel(context.Background())
-	done := make(chan error)
-
-	ep, err := eventfd.NewEpoll()
-	if err != nil {
-		panic(err)
-	}
-	err = ep.AddEvent(sq.callEventFD.FD())
-	if err != nil {
-		panic(err)
-	}
-
-	go func() {
-		done <- sq.consumeUsedRing(ctx, &ep)
-	}()
 	return func() error {
-		cancel()
 
 		// The goroutine blocks until it receives a signal on the event file
 		// descriptor, so it will never notice the context being canceled.
@@ -221,43 +216,28 @@
 		if err := sq.callEventFD.Kick(); err != nil {
 			return fmt.Errorf("wake up goroutine: %w", err)
 		}
-
-		// Wait for the goroutine to end. This prevents the event file
-		// descriptor to be closed while it's still being used.
-		// If the goroutine failed, this is the last chance to propagate the
-		// error so it at least doesn't go unnoticed, even though the error may
-		// be older already.
-		if err := <-done; err != nil {
-			return fmt.Errorf("goroutine: consume used ring: %w", err)
-		}
 		return nil
 	}
 }
 
-// consumeUsedRing runs in a goroutine, waits for the device to signal that it
-// has used descriptor chains and puts all new [UsedElement]s into the channel
-// for them.
-func (sq *SplitQueue) consumeUsedRing(ctx context.Context, epoll *eventfd.Epoll) error {
+// BlockAndGetHeads waits for the device to signal that it has used descriptor
+// chains and returns all new [UsedElement]s.
+func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
 	var n int
 	var err error
 	for ctx.Err() == nil {
-		// Wait for a signal from the device.
-		if n, err = epoll.Block(); err != nil {
-			return fmt.Errorf("wait: %w", err)
+		if n, err = sq.epoll.Block(); err != nil {
+			return nil, fmt.Errorf("wait: %w", err)
 		}
 
 		if n > 0 {
-			_ = epoll.Clear() //???
-
-			// Process all new used elements.
-			for _, usedElement := range sq.usedRing.take() {
-				sq.usedChains <- usedElement
-			}
+			// Clear before take() so a kick racing with the drain isn't lost.
+			_ = sq.epoll.Clear()
+			return sq.usedRing.take(), nil
 		}
 	}
-	return nil
+	return nil, ctx.Err()
 }
 
 // blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up.
@@ -345,6 +325,55 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
 	return head, nil
 }
 
+func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (uint16, error) {
+	sq.ensureInitialized()
+
+	// Synchronize the offering of descriptor chains. While the descriptor table
+	// and available ring are synchronized on their own as well, this does not
+	// protect us from interleaved calls which could cause reordering.
+	// By locking here, we can ensure that all descriptor chains are made
+	// available to the device in the same order as this method was called.
+	sq.offerMutex.Lock()
+	defer sq.offerMutex.Unlock()
+
+	// Create a descriptor chain for the given buffers.
+	var (
+		head uint16
+		err  error
+	)
+	for {
+		head, err = sq.descriptorTable.createDescriptorChain(nil, numInBuffers)
+		if err == nil {
+			break
+		}
+
+		// Direct comparison instead of errors.Is: this is a hot path and the
+		// sentinel error is never wrapped.
+		//goland:noinspection GoDirectComparisonOfErrors
+		if err == ErrNotEnoughFreeDescriptors {
+			if waitFree {
+				// Wait for more free descriptors to be put back into the queue.
+				// If the number of free descriptors is still not sufficient, we'll
+				// land here again.
+				sq.blockForMoreDescriptors()
+				continue
+			} else {
+				return 0, err
+			}
+		}
+		return 0, fmt.Errorf("create descriptor chain: %w", err)
+	}
+
+	// Make the descriptor chain available to the device.
+	sq.availableRing.offer([]uint16{head})
+
+	// Notify the device to make it process the updated available ring.
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return head, fmt.Errorf("notify device: %w", err)
+	}
+
+	return head, nil
+}
+
 func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
 	sq.ensureInitialized()
 
@@ -420,6 +449,11 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
 	return sq.descriptorTable.getDescriptorChain(head)
 }
 
+func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte) (int, error) {
+	sq.ensureInitialized()
+	return sq.descriptorTable.getDescriptorChainContents(head, out)
+}
+
 // FreeDescriptorChain frees the descriptor chain with the given head index.
 // The head index must be one that was returned by a previous call to
 // [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
@@ -447,6 +481,35 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
 	return nil
 }
 
+func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error {
+	sq.ensureInitialized()
+
+	//todo: I don't think we need this lock here?
+	// Synchronize the offering of descriptor chains. While the descriptor table
+	// and available ring are synchronized on their own as well, this does not
+	// protect us from interleaved calls which could cause reordering.
+	// By locking here, we can ensure that all descriptor chains are made
+	// available to the device in the same order as this method was called.
+	//sq.offerMutex.Lock()
+	//defer sq.offerMutex.Unlock()
+
+	//todo: skipping the free may break eventually?
+	//(freeDescriptorChain is not called, so the chain stays intact and is
+	//re-offered as-is)
+	//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
+	//	return fmt.Errorf("free: %w", err)
+	//}
+
+	// Make the descriptor chain available to the device.
+	sq.availableRing.offer([]uint16{head})
+
+	// Notify the device to make it process the updated available ring.
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return fmt.Errorf("notify device: %w", err)
+	}
+
+	return nil
+}
+
 // Close releases all resources used for this queue.
 // The implementation will try to release as many resources as possible and
 // collect potential errors before returning them.
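--
Reviewer note: with consumeUsedRing removed, nothing feeds the usedChains
channel anymore, so Device.ReceivePacket's read from UsedDescriptorChains()
only works once a caller drives BlockAndGetHeads directly; the todos in
listenIn and tun.ReadMany point the same way. Below is a minimal sketch (not
part of this patch) of how the new primitives are meant to compose, assuming
the APIs introduced above. drainReceiveQueue is a hypothetical helper, and
merge-RX (NumBuffers) handling plus virtio-net header stripping are omitted
for brevity:

package sketch

import (
	"context"

	"github.com/slackhq/nebula/overlay/virtqueue"
)

// drainReceiveQueue blocks until the device reports used descriptor chains,
// copies each chain into one of the caller's buffers, and immediately hands
// the chain back to the device so the receive queue stays populated.
func drainReceiveQueue(ctx context.Context, rxq *virtqueue.SplitQueue, bufs [][]byte) (int, error) {
	heads, err := rxq.BlockAndGetHeads(ctx)
	if err != nil {
		return 0, err
	}
	n := 0
	for _, used := range heads {
		if n == len(bufs) {
			break // todo: heads beyond the batch must be carried over, not dropped
		}
		head := uint16(used.DescriptorIndex)
		// Copy the chain contents into the caller's buffer...
		if _, err := rxq.GetDescriptorChainContents(head, bufs[n]); err != nil {
			return n, err
		}
		// ...then recycle the chain: re-offer it without tearing it down.
		if err := rxq.FreeAndOfferDescriptorChains(head); err != nil {
			return n, err
		}
		n++
	}
	return n, nil
}

Recycling instead of free-then-refill is the design choice this checkpoint is
reaching for: the chain's descriptors and guest memory stay intact, so the
per-packet cost drops to one available-ring write plus one kick.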