mirror of https://github.com/slackhq/nebula.git
broken chkpt
interface.go | 22
@@ -18,6 +18,7 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/overlay"
+	"github.com/slackhq/nebula/overlay/virtio"
 	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/udp"
 )
@@ -308,18 +309,31 @@ func (f *Interface) listenOut(q int) {
 	})
 }
 
-func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
+func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
 	runtime.LockOSThread()
 
-	packet := make([]byte, mtu)
+	const batch = 64
+	originalPackets := make([][]byte, batch) //todo batch config
+	for i := 0; i < batch; i++ {
+		originalPackets[i] = make([]byte, 0xffff)
+	}
 	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
+	queues := reader.GetQueues()
+	if len(queues) == 0 {
+		f.l.Fatal("Failed to get queues")
+	}
+	queue := queues[0]
+
 	for {
-		n, err := reader.Read(packet)
-
+		n, err := reader.ReadMany(originalPackets)
+		//todo!!
+		pkt := originalPackets[0][virtio.NetHdrSize : n+virtio.NetHdrSize]
 		if err != nil {
 			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
 				return
@@ -330,7 +344,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 			os.Exit(2)
 		}
 
-		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		f.consumeInsidePacket(pkt, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l))
 	}
 }
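The batched read path hands back buffers that still carry the virtio-net header in front of the IP packet, which is why the loop slices at virtio.NetHdrSize before handing the packet to the firewall. A minimal, self-contained sketch of that framing; the 12-byte header size is an assumption matching the mergeable-rx-buffers layout, and all names here are illustrative rather than Nebula's API:

package main

import "fmt"

// netHdrSize stands in for virtio.NetHdrSize: flags(1) + gso_type(1) +
// hdr_len(2) + gso_size(2) + csum_start(2) + csum_offset(2) + num_buffers(2).
const netHdrSize = 12

// payload returns the packet bytes of a buffer the device filled with a
// virtio-net header followed by n bytes of packet data.
func payload(buf []byte, n int) []byte {
	return buf[netHdrSize : netHdrSize+n]
}

func main() {
	buf := make([]byte, 0xffff)
	copy(buf[netHdrSize:], []byte{0x45, 0x00, 0x00, 0x14}) // fake start of an IPv4 header
	fmt.Printf("%% x\n", payload(buf, 4))                  // 45 00 00 14
}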
@@ -2,20 +2,21 @@ package overlay
 
 import (
 	"fmt"
 	"io"
 	"net"
 	"net/netip"
 
	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/util"
 )
 
 const DefaultMTU = 1300
 
 type TunDev interface {
 	io.ReadWriteCloser
+	ReadMany([][]byte) (int, error)
+	WriteMany([][]byte) (int, error)
+	GetQueues() []*virtqueue.SplitQueue
 }
 
 // TODO: We may be able to remove routines
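A consumer of this interface can preallocate its batch once and reuse it across reads. A hedged sketch, written as if inside package overlay; the names and sizes are illustrative, and the single-int return is assumed to be the byte count of the first packet, which is how the fallback implementations below behave:

// drainTun is a hypothetical batch read loop over the TunDev interface.
func drainTun(dev TunDev, handle func(pkt []byte)) error {
	bufs := make([][]byte, 64) // batch size mirrors listenIn above
	for i := range bufs {
		bufs[i] = make([]byte, 0xffff)
	}
	for {
		n, err := dev.ReadMany(bufs)
		if err != nil {
			return err
		}
		handle(bufs[0][:n])
	}
}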
@@ -9,6 +9,7 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/routing"
 )
 
@@ -40,6 +41,10 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
 	return tun
 }
 
+func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
+	return nil
+}
+
 func (*disabledTun) Activate() error {
 	return nil
 }
@@ -117,6 +122,10 @@ func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
 	return out, nil
 }
 
+func (t *disabledTun) ReadMany(b [][]byte) (int, error) {
+	return t.Read(b[0])
+}
+
 func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
 	return t, nil
 }
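Both disabledTun here and UserDevice further down satisfy the new batch methods by delegating to the single-buffer calls; the pattern generalizes to any io.Reader. A sketch of that adapter, where readManyViaRead is a hypothetical helper and not part of the diff:

package overlaysketch

import "io"

// readManyViaRead adapts a single-buffer Read to the ReadMany contract used
// in this diff: fill only the first buffer and report its byte count.
func readManyViaRead(r io.Reader, b [][]byte) (int, error) {
	return r.Read(b[0])
}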
@@ -5,7 +5,6 @@ package overlay
 
 import (
 	"fmt"
-	"io"
 	"net"
 	"net/netip"
 	"os"
@@ -20,6 +19,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/overlay/vhostnet"
+	"github.com/slackhq/nebula/overlay/virtio"
 	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/vishvananda/netlink"
@@ -27,7 +27,7 @@ import (
 )
 
 type tun struct {
-	io.ReadWriteCloser
+	file *os.File
 	fd   int
 	vdev *vhostnet.Device
 	Device string
@@ -51,6 +51,10 @@ func (t *tun) Networks() []netip.Prefix {
 	return t.vpnNetworks
 }
 
+func (t *tun) GetQueues() []*virtqueue.SplitQueue {
+	return []*virtqueue.SplitQueue{t.vdev.ReceiveQueue, t.vdev.TransmitQueue}
+}
+
 type ifReq struct {
 	Name  [16]byte
 	Flags uint16
@@ -129,8 +133,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	}
 
 	flags := 0
-	//flags := unix.TUN_F_CSUM
-	//|unix.TUN_F_USO4|unix.TUN_F_USO6
+	//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
 	err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
 	if err != nil {
 		return nil, fmt.Errorf("set offloads: %w", err)
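The commented-out flags mark where offloads would be re-enabled once the vhost path can handle the resulting frames. For reference, a hedged sketch of turning the full set back on through the same ioctl; note the kernel rejects the TSO/USO flags unless TUN_F_CSUM is also set:

package tunsketch

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// setOffloads sketches re-enabling the offloads the diff currently leaves
// disabled (flags := 0). Pass 0 to turn everything back off.
func setOffloads(fd int) error {
	flags := unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 |
		unix.TUN_F_USO4 | unix.TUN_F_USO6
	if err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags); err != nil {
		return fmt.Errorf("set offloads: %w", err)
	}
	return nil
}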
@@ -168,7 +171,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 
 func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
-		ReadWriteCloser: file,
+		file:        file,
 		fd:          int(file.Fd()),
 		vpnNetworks: vpnNetworks,
 		TXQueueLen:  c.GetInt("tun.tx_queue", 500),
@@ -699,8 +702,8 @@ func (t *tun) Close() error {
 		_ = t.vdev.Close()
 	}
 
-	if t.ReadWriteCloser != nil {
-		_ = t.ReadWriteCloser.Close()
+	if t.file != nil {
+		_ = t.file.Close()
 	}
 
 	if t.ioctlFd > 0 {
@@ -710,17 +713,17 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func (t *tun) Read(p []byte) (int, error) {
-	hdr, out, err := t.vdev.ReceivePacket() //we are TXing
+func (t *tun) ReadMany(p [][]byte) (int, error) {
+	//todo call consumeUsedRing here instead of its own thread
+
+	n, hdr, err := t.vdev.ReceivePacket(p[0]) //we are TXing
 	if err != nil {
 		return 0, err
 	}
 	if hdr.NumBuffers > 1 {
 		t.l.WithField("num_buffers", hdr.NumBuffers).Info("wow, lots to TX from tun")
 	}
-	p = p[:len(out)]
-	copy(p, out)
-	return len(out), nil
+	return n, nil
 }
 
 func (t *tun) Write(b []byte) (int, error) {
@@ -6,6 +6,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay/virtqueue"
 	"github.com/slackhq/nebula/routing"
 )
 
@@ -66,6 +67,10 @@ func (d *UserDevice) Close() error {
 	return nil
 }
 
+func (d *UserDevice) ReadMany(b [][]byte) (int, error) {
+	return d.Read(b[0])
+}
+
 func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
 	out := 0
 	for i := range b {
@@ -77,3 +82,7 @@ func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
 	}
 	return out, nil
 }
+
+func (*UserDevice) GetQueues() []*virtqueue.SplitQueue {
+	return nil
+}
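The body of WriteMany is elided by the hunk above; from the visible counter and loop it appears to write one packet at a time and report how many it got through. A hedged sketch of that shape, with hypothetical names and no claim to match the elided code exactly:

package usersketch

import "io"

// writeManyViaWrite sketches the WriteMany contract implied here: write each
// packet with a single Write and return how many packets succeeded.
func writeManyViaWrite(w io.Writer, b [][]byte) (int, error) {
	out := 0
	for i := range b {
		if _, err := w.Write(b[i]); err != nil {
			return out, err
		}
		out++
	}
	return out, nil
}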
@@ -28,8 +28,8 @@ type Device struct {
 	initialized bool
 	controlFD   int
 
-	receiveQueue  *virtqueue.SplitQueue
-	transmitQueue *virtqueue.SplitQueue
+	ReceiveQueue  *virtqueue.SplitQueue
+	TransmitQueue *virtqueue.SplitQueue
 
 	// transmitted contains channels for each possible descriptor chain head
 	// index. This is used for packet transmit notifications.
@@ -96,17 +96,17 @@ func NewDevice(options ...Option) (*Device, error) {
 	}
 
 	// Initialize and register the queues needed for the networking device.
-	if dev.receiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
+	if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
 		return nil, fmt.Errorf("create receive queue: %w", err)
 	}
-	if dev.transmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
+	if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
 		return nil, fmt.Errorf("create transmit queue: %w", err)
 	}
 
 	// Set up memory mappings for all buffers used by the queues. This has to
 	// happen before a backend for the queues can be registered.
 	memoryLayout := vhost.NewMemoryLayoutForQueues(
-		[]*virtqueue.SplitQueue{dev.receiveQueue, dev.transmitQueue},
+		[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
 	)
 	if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
 		return nil, fmt.Errorf("setup memory layout: %w", err)
@@ -127,7 +127,7 @@ func NewDevice(options ...Option) (*Device, error) {
 	}
 
 	// Initialize channels for transmit notifications.
-	dev.transmitted = make([]chan virtqueue.UsedElement, dev.transmitQueue.Size())
+	dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
 	for i := range len(dev.transmitted) {
 		// It is important to use a single-element buffered channel here.
 		// When the channel was unbuffered and the monitorTransmitQueue
@@ -159,7 +159,7 @@ func NewDevice(options ...Option) (*Device, error) {
 // in the transmit queue and produces a transmit notification via the
 // corresponding channel.
 func (dev *Device) monitorTransmitQueue() {
-	usedChan := dev.transmitQueue.UsedDescriptorChains()
+	usedChan := dev.TransmitQueue.UsedDescriptorChains()
 	for {
 		used, ok := <-usedChan
 		if !ok {
@@ -180,7 +180,7 @@
 // packets.
 func (dev *Device) refillReceiveQueue() error {
 	for {
-		_, err := dev.receiveQueue.OfferDescriptorChain(nil, 1, false)
+		_, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false)
 		if err != nil {
 			if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
 				// Queue is full, job is done.
@@ -212,17 +212,17 @@ func (dev *Device) Close() error {
 
 	var errs []error
 
-	if dev.receiveQueue != nil {
-		if err := dev.receiveQueue.Close(); err == nil {
-			dev.receiveQueue = nil
+	if dev.ReceiveQueue != nil {
+		if err := dev.ReceiveQueue.Close(); err == nil {
+			dev.ReceiveQueue = nil
 		} else {
 			errs = append(errs, fmt.Errorf("close receive queue: %w", err))
 		}
 	}
 
-	if dev.transmitQueue != nil {
-		if err := dev.transmitQueue.Close(); err == nil {
-			dev.transmitQueue = nil
+	if dev.TransmitQueue != nil {
+		if err := dev.TransmitQueue.Close(); err == nil {
+			dev.TransmitQueue = nil
 		} else {
 			errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
 		}
@@ -296,7 +296,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
 	outBuffers := [][]byte{vnethdrBuf, packet}
 	//outBuffers := [][]byte{packet}
 
-	chainIndex, err := dev.transmitQueue.OfferDescriptorChain(outBuffers, 0, true)
+	chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
 	if err != nil {
 		return fmt.Errorf("offer descriptor chain: %w", err)
 	}
@@ -304,7 +304,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
 	// Wait for the packet to have been transmitted.
 	<-dev.transmitted[chainIndex]
 
-	if err = dev.transmitQueue.FreeDescriptorChain(chainIndex); err != nil {
+	if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
 		return fmt.Errorf("free descriptor chain: %w", err)
 	}
 
@@ -320,7 +320,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 	vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
 	vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
 
-	chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
+	chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
 	if err != nil {
 		return fmt.Errorf("offer descriptor chain: %w", err)
 	}
@@ -330,7 +330,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 	for i := range chainIndexes {
 		<-dev.transmitted[chainIndexes[i]]
 
-		if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
+		if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
 			return fmt.Errorf("free descriptor chain: %w", err)
 		}
 	}
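TransmitPackets prepends one shared buffer holding a virtio-net header plus a synthetic 14-byte Ethernet header, and patches the final two bytes (the EtherType) to 0x86dd for IPv6, which is what the //todo above is flagging. A worked sketch of that layout, assuming a 12-byte virtio-net header; the helper name is illustrative:

package vhostsketch

const (
	netHdrSize  = 12 // assumed virtio.NetHdrSize with mergeable rx buffers
	etherHdrLen = 14 // dst MAC (6) + src MAC (6) + EtherType (2)
)

// newIPv6Prefix builds the prepend buffer the diff uses: a zeroed virtio-net
// header followed by a zeroed Ethernet header whose EtherType is 0x86dd.
func newIPv6Prefix() []byte {
	buf := make([]byte, netHdrSize+etherHdrLen)
	buf[netHdrSize+etherHdrLen-2] = 0x86
	buf[netHdrSize+etherHdrLen-1] = 0xdd
	return buf
}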
@@ -346,7 +346,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 // When this method returns an error, the receive queue will likely be in a
 // broken state which this implementation cannot recover from. The caller should
 // close the device and not attempt any additional receives.
-func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
+func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
 	var (
 		chainHeads []uint16
@@ -358,41 +358,30 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
 		packetLength = -virtio.NetHdrSize
 	)
 
+	lenRead := 0
+
 	// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
 	// made of multiple descriptor chains which are to be merged.
 	for remainingChains := 1; remainingChains > 0; remainingChains-- {
 		// Get the next descriptor chain.
-		usedElement, ok := <-dev.receiveQueue.UsedDescriptorChains()
+		usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
 		if !ok {
-			return virtio.NetHdr{}, nil, ErrDeviceClosed
+			return 0, virtio.NetHdr{}, ErrDeviceClosed
 		}
 
 		// Track this chain to be freed later.
 		head := uint16(usedElement.DescriptorIndex)
 		chainHeads = append(chainHeads, head)
 
-		outBuffers, inBuffers, err := dev.receiveQueue.GetDescriptorChain(head)
+		n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:])
 		if err != nil {
 			// When this fails we may miss to free some descriptor chains. We
 			// could try to mitigate this by deferring the freeing somehow, but
 			// it's not worth the hassle. When this method fails, the queue will
 			// be in a broken state anyway.
-			return virtio.NetHdr{}, nil, fmt.Errorf("get descriptor chain: %w", err)
+			return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err)
 		}
-		if len(outBuffers) > 0 {
-			// How did this happen!?
-			panic("receive queue contains device-readable buffers")
-		}
-		if len(inBuffers) == 0 {
-			// Empty descriptor chains should not be possible.
-			panic("descriptor chain contains no buffers")
-		}
 
 		// The device tells us how many bytes of the descriptor chain it has
 		// actually written to. The specification forces the device to fully
 		// fill up all but the last descriptor chain when multiple descriptor
 		// chains are being merged, but being more compatible here doesn't hurt.
-		inBuffers = truncateBuffers(inBuffers, int(usedElement.Length))
+		lenRead += n
 		packetLength += int(usedElement.Length)
 
 		// Is this the first descriptor chain we process?
@@ -403,49 +392,51 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
 			// descriptor chain, but it is reasonable to assume that this is
 			// always the case.
 			// The decode method already does the buffer length check.
-			if err = vnethdr.Decode(inBuffers[0]); err != nil {
+			if err = vnethdr.Decode(out[0:]); err != nil {
 				// The device misbehaved. There is no way we can gracefully
 				// recover from this, because we don't know how many of the
 				// following descriptor chains belong to this packet.
-				return virtio.NetHdr{}, nil, fmt.Errorf("decode vnethdr: %w", err)
+				return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err)
 			}
-			inBuffers[0] = inBuffers[0][virtio.NetHdrSize:]
+			lenRead = 0
+			out = out[virtio.NetHdrSize:]
 
 			// The virtio-net header tells us how many descriptor chains this
 			// packet is long.
 			remainingChains = int(vnethdr.NumBuffers)
 		}
 
-		buffers = append(buffers, inBuffers...)
+		//buffers = append(buffers, inBuffers...)
 	}
 
-	// Copy all the buffers together to produce the complete packet slice.
-	packet := make([]byte, packetLength)
-	copied := 0
-	for _, buffer := range buffers {
-		copied += copy(packet[copied:], buffer)
-	}
-	if copied != packetLength {
-		panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
-	}
+	//out = out[:packetLength]
+	//copied := 0
+	//for _, buffer := range buffers {
+	//	copied += copy(out[copied:], buffer)
+	//}
+	//if copied != packetLength {
+	//	panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
+	//}
 
 	// Now that we have copied all buffers, we can free the used descriptor
 	// chains again.
 	// TODO: Recycling the descriptor chains would be more efficient than
 	// freeing them just to offer them again right after.
 	for _, head := range chainHeads {
-		if err := dev.receiveQueue.FreeDescriptorChain(head); err != nil {
-			return virtio.NetHdr{}, nil, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
+		if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
+			return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
 		}
 	}
 
+	//if we don't churn chains, maybe we don't need this?
 	// It's advised to always keep the receive queue fully populated with
 	// available buffers which the device can write new packets into.
-	if err := dev.refillReceiveQueue(); err != nil {
-		return virtio.NetHdr{}, nil, fmt.Errorf("refill receive queue: %w", err)
-	}
+	//if err := dev.refillReceiveQueue(); err != nil {
+	//	return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
+	//}
 
-	return vnethdr, packet, nil
+	return packetLength, vnethdr, nil
 }
 
 // TODO: Make above methods cancelable by taking a context.Context argument?
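The rewritten ReceivePacket keeps two counters: lenRead tracks bytes copied into out, while packetLength starts at -virtio.NetHdrSize so that summing every merged chain's usedElement.Length yields the payload size with the header counted out. A small worked example of that arithmetic, with a 12-byte header assumed:

package vhostsketch

// payloadLength mirrors the accounting in ReceivePacket: the device reports
// how many bytes it wrote per descriptor chain, and the virtio-net header
// occurs exactly once, in the first chain, so it is subtracted once.
func payloadLength(chainLengths []int, netHdrSize int) int {
	packetLength := -netHdrSize
	for _, l := range chainLengths {
		packetLength += l
	}
	return packetLength
}

// Example: three merged chains of 1500, 1500 and 140 bytes with a 12-byte
// header give a 3128-byte packet (1500+1500+140-12).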
@@ -349,6 +349,74 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
 	return
 }
 
+func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (int, error) {
+	if int(head) > len(dt.descriptors) {
+		return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	dt.mu.Lock()
+	defer dt.mu.Unlock()
+
+	// Iterate over the chain. The iteration is limited to the queue size to
+	// avoid ending up in an endless loop when things go very wrong.
+
+	length := 0
+	//find length
+	next := head
+	for range len(dt.descriptors) {
+		if next == dt.freeHeadIndex {
+			return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
+		}
+
+		desc := &dt.descriptors[next]
+
+		if desc.flags&descriptorFlagWritable == 0 {
+			return 0, fmt.Errorf("receive queue contains device-readable buffer")
+		}
+		length += int(desc.length)
+
+		// Is this the tail of the chain?
+		if desc.flags&descriptorFlagHasNext == 0 {
+			break
+		}
+
+		// Detect loops.
+		if desc.next == head {
+			return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
+		}
+
+		next = desc.next
+	}
+
+	//set out to length:
+	out = out[:length]
+
+	//now do the copying, walking the chain again from the head
+	copied := 0
+	next = head
+	for range len(dt.descriptors) {
+		desc := &dt.descriptors[next]
+
+		// The descriptor address points to memory not managed by Go, so this
+		// conversion is safe. See https://github.com/golang/go/issues/58625
+		//goland:noinspection GoVetUnsafePointer
+		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+		copied += copy(out[copied:], bs)
+
+		// Is this the tail of the chain?
+		if desc.flags&descriptorFlagHasNext == 0 {
+			break
+		}
+
+		// we did this already, no need to detect loops.
+		next = desc.next
+	}
+	if copied != length {
+		panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
+	}
+
+	return length, nil
+}
+
 // freeDescriptorChain can be used to free a descriptor chain when it is no
 // longer in use. The descriptor chain that starts with the given index will be
 // put back into the free chain, so the descriptors can be used for later calls
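The copy loop views each descriptor's buffer through unsafe.Slice over a raw address taken from the descriptor table. A runnable sketch of that conversion, backed by ordinary Go memory so it is safe to execute; in the real code the address points into the queue's mmap'd region, not the Go heap:

package main

import (
	"fmt"
	"runtime"
	"unsafe"
)

func main() {
	backing := []byte{0xde, 0xad, 0xbe, 0xef}

	// In the descriptor table the address is a uint64; here we fake it with
	// the address of a Go slice element for demonstration purposes.
	address := uint64(uintptr(unsafe.Pointer(&backing[0])))
	length := uint32(len(backing))

	// Reconstruct a []byte view over the raw address, as the copy loop does.
	bs := unsafe.Slice((*byte)(unsafe.Pointer(uintptr(address))), length)

	out := make([]byte, length)
	fmt.Println(copy(out, bs), out) // 4 [222 173 190 239]
	runtime.KeepAlive(backing)
}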
@@ -49,6 +49,8 @@ type SplitQueue struct {
 	offerMutex sync.Mutex
 	pageSize   int
+	itemSize   int
 
+	epoll eventfd.Epoll
 }
 
 // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
@@ -132,6 +134,15 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
 	sq.usedChains = make(chan UsedElement, queueSize)
 	sq.moreFreeDescriptors = make(chan struct{})
 
+	sq.epoll, err = eventfd.NewEpoll()
+	if err != nil {
+		return nil, err
+	}
+	err = sq.epoll.AddEvent(sq.callEventFD.FD())
+	if err != nil {
+		return nil, err
+	}
+
 	// Consume used buffer notifications in the background.
 	sq.stop = sq.startConsumeUsedRing()
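The queue now owns its epoll instance and registers the call eventfd at construction time instead of inside the consumer goroutine. For readers unfamiliar with the mechanism, a minimal runnable round trip of the same eventfd-plus-epoll signaling, Linux only and using golang.org/x/sys/unix directly rather than the repo's eventfd wrapper:

package main

import (
	"encoding/binary"
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	efd, err := unix.Eventfd(0, 0)
	if err != nil {
		panic(err)
	}
	epfd, err := unix.EpollCreate1(0)
	if err != nil {
		panic(err)
	}
	ev := unix.EpollEvent{Events: unix.EPOLLIN, Fd: int32(efd)}
	if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, efd, &ev); err != nil {
		panic(err)
	}

	// "Kick": the writer side increments the eventfd counter to wake waiters.
	var kick [8]byte
	binary.LittleEndian.PutUint64(kick[:], 1)
	if _, err := unix.Write(efd, kick[:]); err != nil {
		panic(err)
	}

	// Block until the eventfd becomes readable, as sq.epoll.Block() does.
	events := make([]unix.EpollEvent, 1)
	n, err := unix.EpollWait(epfd, events, -1)
	fmt.Println(n, err) // 1 <nil>
}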
@@ -194,25 +205,9 @@ func (sq *SplitQueue) UsedDescriptorChains() chan UsedElement {
 }
 
 // startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
-// A function is returned that can be used to gracefully cancel it.
+// A function is returned that can be used to gracefully cancel it. todo rename
 func (sq *SplitQueue) startConsumeUsedRing() func() error {
 	ctx, cancel := context.WithCancel(context.Background())
 	done := make(chan error)
 
-	ep, err := eventfd.NewEpoll()
-	if err != nil {
-		panic(err)
-	}
-	err = ep.AddEvent(sq.callEventFD.FD())
-	if err != nil {
-		panic(err)
-	}
-
 	go func() {
 		done <- sq.consumeUsedRing(ctx, &ep)
 	}()
 	return func() error {
 		cancel()
 
 		// The goroutine blocks until it receives a signal on the event file
 		// descriptor, so it will never notice the context being canceled.
@@ -221,43 +216,28 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
 		if err := sq.callEventFD.Kick(); err != nil {
 			return fmt.Errorf("wake up goroutine: %w", err)
 		}
 
 		// Wait for the goroutine to end. This prevents the event file
 		// descriptor from being closed while it's still being used.
 		// If the goroutine failed, this is the last chance to propagate the
 		// error so it at least doesn't go unnoticed, even though the error may
 		// be older already.
 		if err := <-done; err != nil {
 			return fmt.Errorf("goroutine: consume used ring: %w", err)
 		}
 		return nil
 	}
 }
 
-// consumeUsedRing runs in a goroutine, waits for the device to signal that it
-// has used descriptor chains and puts all new [UsedElement]s into the channel
-// for them.
-func (sq *SplitQueue) consumeUsedRing(ctx context.Context, epoll *eventfd.Epoll) error {
+// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s
+func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
 	var n int
 	var err error
 	for ctx.Err() == nil {
 
 		// Wait for a signal from the device.
-		if n, err = epoll.Block(); err != nil {
-			return fmt.Errorf("wait: %w", err)
+		if n, err = sq.epoll.Block(); err != nil {
+			return nil, fmt.Errorf("wait: %w", err)
 		}
 
 		if n > 0 {
-			_ = epoll.Clear() //???
-
-			// Process all new used elements.
-			for _, usedElement := range sq.usedRing.take() {
-				sq.usedChains <- usedElement
-			}
+			out := sq.usedRing.take()
+			_ = sq.epoll.Clear() //???
+			return out, nil
 		}
 	}
 
-	return nil
+	return nil, ctx.Err()
 }
 
 // blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up.
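With consumeUsedRing replaced, used-ring consumption becomes a pull API: the caller blocks, takes a batch of [UsedElement]s, and decides whether to free or recycle each head. A hedged sketch of the consumer loop this implies, written as if it lived in package virtqueue; drainUsed is hypothetical, and note the diff's own startConsumeUsedRing still references the removed function, so this wiring is unfinished in this checkpoint:

// drainUsed is a hypothetical consumer of the new pull-style API (sketch,
// assumes package virtqueue scope for SplitQueue and UsedElement).
func drainUsed(ctx context.Context, sq *SplitQueue) error {
	for {
		elems, err := sq.BlockAndGetHeads(ctx)
		if err != nil {
			return err // ctx.Err() once the context is canceled
		}
		for _, e := range elems {
			// Each element names the head of a used descriptor chain:
			// free it, or hand it straight back via FreeAndOfferDescriptorChains.
			_ = e.DescriptorIndex
		}
	}
}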
@@ -345,6 +325,55 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
 	return head, nil
 }
 
+func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (uint16, error) {
+	sq.ensureInitialized()
+
+	// Synchronize the offering of descriptor chains. While the descriptor table
+	// and available ring are synchronized on their own as well, this does not
+	// protect us from interleaved calls which could cause reordering.
+	// By locking here, we can ensure that all descriptor chains are made
+	// available to the device in the same order as this method was called.
+	sq.offerMutex.Lock()
+	defer sq.offerMutex.Unlock()
+
+	// Create a descriptor chain for the given buffers.
+	var (
+		head uint16
+		err  error
+	)
+	for {
+		head, err = sq.descriptorTable.createDescriptorChain(nil, numInBuffers)
+		if err == nil {
+			break
+		}
+
+		// I don't wanna use errors.Is, it's slow
+		//goland:noinspection GoDirectComparisonOfErrors
+		if err == ErrNotEnoughFreeDescriptors {
+			if waitFree {
+				// Wait for more free descriptors to be put back into the queue.
+				// If the number of free descriptors is still not sufficient, we'll
+				// land here again.
+				sq.blockForMoreDescriptors()
+				continue
+			} else {
+				return 0, err
+			}
+		}
+		return 0, fmt.Errorf("create descriptor chain: %w", err)
+	}
+
+	// Make the descriptor chain available to the device.
+	sq.availableRing.offer([]uint16{head})
+
+	// Notify the device to make it process the updated available ring.
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return head, fmt.Errorf("notify device: %w", err)
+	}
+
+	return head, nil
+}
+
 func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
 	sq.ensureInitialized()
@@ -420,6 +449,11 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
 	return sq.descriptorTable.getDescriptorChain(head)
 }
 
+func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte) (int, error) {
+	sq.ensureInitialized()
+	return sq.descriptorTable.getDescriptorChainContents(head, out)
+}
+
 // FreeDescriptorChain frees the descriptor chain with the given head index.
 // The head index must be one that was returned by a previous call to
 // [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
@@ -447,6 +481,35 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
 	return nil
 }
 
+func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error {
+	sq.ensureInitialized()
+
+	//todo I don't think we need this here?
+	// Synchronize the offering of descriptor chains. While the descriptor table
+	// and available ring are synchronized on their own as well, this does not
+	// protect us from interleaved calls which could cause reordering.
+	// By locking here, we can ensure that all descriptor chains are made
+	// available to the device in the same order as this method was called.
+	//sq.offerMutex.Lock()
+	//defer sq.offerMutex.Unlock()
+
+	//todo not doing this may break eventually?
+	//not called under lock
+	//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
+	//	return fmt.Errorf("free: %w", err)
+	//}
+
+	// Make the descriptor chain available to the device.
+	sq.availableRing.offer([]uint16{head})
+
+	// Notify the device to make it process the updated available ring.
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return fmt.Errorf("notify device: %w", err)
+	}
+
+	return nil
+}
+
 // Close releases all resources used for this queue.
 // The implementation will try to release as many resources as possible and
 // collect potential errors before returning them.