broken chkpt

This commit is contained in:
JackDoan
2025-11-11 15:37:54 -06:00
parent e7f01390a3
commit 1719149594
16 changed files with 275 additions and 303 deletions

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
@@ -18,7 +17,6 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/overlay/virtio"
"github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -270,6 +268,7 @@ func (f *Interface) listenOut(q int) {
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
outPackets := make([]*packet.OutPacket, batch) outPackets := make([]*packet.OutPacket, batch)
for i := 0; i < batch; i++ { for i := 0; i < batch; i++ {
outPackets[i] = packet.NewOut() outPackets[i] = packet.NewOut()
@@ -295,16 +294,15 @@ func (f *Interface) listenOut(q int) {
if len(outPackets[i].Segments[j]) > 0 { if len(outPackets[i].Segments[j]) > 0 {
toSend = append(toSend, outPackets[i].Segments[j]) toSend = append(toSend, outPackets[i].Segments[j])
} }
} }
//toSend = append(toSend, outPackets[i])
//toSendCount++
} }
} }
//toSend = toSend[:toSendCount] //toSend = toSend[:toSendCount]
_, err := f.readers[q].WriteMany(toSend) if len(toSend) != 0 {
if err != nil { _, err := f.readers[q].WriteMany(toSend)
f.l.WithError(err).Error("Failed to write messages") if err != nil {
f.l.WithError(err).Error("Failed to write messages")
}
} }
}) })
} }
@@ -323,17 +321,15 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
queues := reader.GetQueues() packets := make([]*packet.VirtIOPacket, batch)
if len(queues) == 0 { for i := 0; i < batch; i++ {
f.l.Fatal("Failed to get queues") packets[i] = packet.NewVIO()
} }
queue := queues[0]
for { for {
n, err := reader.ReadMany(originalPacket) n, err := reader.ReadMany(packets)
//todo!! //todo!!
pkt := originalPacket[virtio.NetHdrSize : n+virtio.NetHdrSize]
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() { if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return return
@@ -344,7 +340,11 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
os.Exit(2) os.Exit(2)
} }
f.consumeInsidePacket(pkt, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l)) //todo vectorize
for _, pkt := range packets[:n] {
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l))
}
} }
} }

View File

@@ -2,19 +2,22 @@ package overlay
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
const DefaultMTU = 1300 const DefaultMTU = 1300
type TunDev interface { type TunDev interface {
ReadMany([][]byte) (int, error) io.WriteCloser
ReadMany([]*packet.VirtIOPacket) (int, error)
WriteMany([][]byte) (int, error) WriteMany([][]byte) (int, error)
GetQueues() []*virtqueue.SplitQueue GetQueues() []*virtqueue.SplitQueue
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -122,8 +123,8 @@ func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
return out, nil return out, nil
} }
func (t *disabledTun) ReadMany(b [][]byte) (int, error) { func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket) (int, error) {
return t.Read(b[0]) return t.Read(b[0].Payload)
} }
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) { func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {

View File

@@ -18,10 +18,11 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/vhostnet" "github.com/slackhq/nebula/overlay/vhostnet"
"github.com/slackhq/nebula/overlay/virtio"
"github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/util/virtio"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -713,16 +714,11 @@ func (t *tun) Close() error {
return nil return nil
} }
func (t *tun) ReadMany(p [][]byte) (int, error) { func (t *tun) ReadMany(p []*packet.VirtIOPacket) (int, error) {
//todo call consumeUsedRing here instead of its own thread n, err := t.vdev.ReceivePackets(p) //we are TXing
n, hdr, err := t.vdev.ReceivePacket(p) //we are TXing
if err != nil { if err != nil {
return 0, err return 0, err
} }
if hdr.NumBuffers > 1 {
t.l.WithField("num_buffers", hdr.NumBuffers).Info("wow, lots to TX from tun")
}
return n, nil return n, nil
} }
@@ -739,7 +735,7 @@ func (t *tun) Write(b []byte) (int, error) {
NumBuffers: 0, NumBuffers: 0,
} }
err := t.vdev.TransmitPacket(hdr, b) err := t.vdev.TransmitPackets(hdr, [][]byte{b})
if err != nil { if err != nil {
t.l.WithError(err).Error("Transmitting packet") t.l.WithError(err).Error("Transmitting packet")
return 0, err return 0, err

View File

@@ -7,6 +7,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -67,8 +68,8 @@ func (d *UserDevice) Close() error {
return nil return nil
} }
func (d *UserDevice) ReadMany(b [][]byte) (int, error) { func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket) (int, error) {
return d.Read(b[0]) return d.Read(b[0].Payload)
} }
func (d *UserDevice) WriteMany(b [][]byte) (int, error) { func (d *UserDevice) WriteMany(b [][]byte) (int, error) {

View File

@@ -4,8 +4,8 @@ import (
"fmt" "fmt"
"unsafe" "unsafe"
"github.com/slackhq/nebula/overlay/virtio"
"github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )

View File

@@ -1,14 +1,16 @@
package vhostnet package vhostnet
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
"runtime" "runtime"
"github.com/slackhq/nebula/overlay/vhost" "github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtio"
"github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -31,12 +33,7 @@ type Device struct {
ReceiveQueue *virtqueue.SplitQueue ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue TransmitQueue *virtqueue.SplitQueue
// transmitted contains channels for each possible descriptor chain head extraRx []virtqueue.UsedElement
// index. This is used for packet transmit notifications.
// When a packet was transmitted and the descriptor chain was used by the
// device, the corresponding channel receives the [virtqueue.UsedElement]
// instance provided by the device.
transmitted []chan virtqueue.UsedElement
} }
// NewDevice initializes a new vhost networking device within the // NewDevice initializes a new vhost networking device within the
@@ -126,25 +123,6 @@ func NewDevice(options ...Option) (*Device, error) {
return nil, fmt.Errorf("refill receive queue: %w", err) return nil, fmt.Errorf("refill receive queue: %w", err)
} }
// Initialize channels for transmit notifications.
dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
for i := range len(dev.transmitted) {
// It is important to use a single-element buffered channel here.
// When the channel was unbuffered and the monitorTransmitQueue
// goroutine would write into it, the writing would block which could
// lead to deadlocks in case transmit notifications do not arrive in
// order.
// When the goroutine would use fire-and-forget to write into that
// channel, there may be a chance that the TransmitPacket does not
// receive the transmit notification due to this being a race condition.
// Buffering a single transmit notification resolves this without race
// conditions or possible deadlocks.
dev.transmitted[i] = make(chan virtqueue.UsedElement, 1)
}
// Monitor transmit queue in background.
go dev.monitorTransmitQueue()
dev.initialized = true dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without // Make sure to clean up even when the device gets garbage collected without
@@ -155,32 +133,12 @@ func NewDevice(options ...Option) (*Device, error) {
return devPtr, nil return devPtr, nil
} }
// monitorTransmitQueue waits for the device to advertise used descriptor chains
// in the transmit queue and produces a transmit notification via the
// corresponding channel.
func (dev *Device) monitorTransmitQueue() {
usedChan := dev.TransmitQueue.UsedDescriptorChains()
for {
used, ok := <-usedChan
if !ok {
// The queue was closed.
return
}
if int(used.DescriptorIndex) > len(dev.transmitted) {
panic(fmt.Sprintf("device provided a used descriptor index (%d) that is out of range",
used.DescriptorIndex))
}
dev.transmitted[used.DescriptorIndex] <- used
}
}
// refillReceiveQueue offers as many new device-writable buffers to the device // refillReceiveQueue offers as many new device-writable buffers to the device
// as the queue can fit. The device will then use these to write received // as the queue can fit. The device will then use these to write received
// packets. // packets.
func (dev *Device) refillReceiveQueue() error { func (dev *Device) refillReceiveQueue() error {
for { for {
_, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false) _, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
if err != nil { if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) { if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done. // Queue is full, job is done.
@@ -279,38 +237,6 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
return return
} }
// TransmitPacket writes the given packet into the transmit queue of this
// device. The packet will be prepended with the [virtio.NetHdr].
//
// When the queue is full, this will block until the queue has enough room to
// transmit the packet. This method will not return before the packet was
// transmitted and the device notifies that it has used the packet buffer.
func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
// Prepend the packet with its virtio-net header.
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
if err := vnethdr.Encode(vnethdrBuf); err != nil {
return fmt.Errorf("encode vnethdr: %w", err)
}
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
outBuffers := [][]byte{vnethdrBuf, packet}
//outBuffers := [][]byte{packet}
chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
// Wait for the packet to have been transmitted.
<-dev.transmitted[chainIndex]
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
return nil
}
func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error { func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
// Prepend the packet with its virtio-net header. // Prepend the packet with its virtio-net header.
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
@@ -320,15 +246,42 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86 vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true) chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
if err != nil { if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err) return fmt.Errorf("offer descriptor chain: %w", err)
} }
//todo surely there's something better to do here
doneYet := map[uint16]bool{}
for _, chain := range chainIndexes {
doneYet[chain] = false
}
for {
txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO())
if err != nil {
return err
} else if len(txedChains) == 0 {
continue //todo will this ever exit?
}
for c := range txedChains {
doneYet[txedChains[c].GetHead()] = true
}
done := true //optimism!
for _, x := range doneYet {
if !x {
done = false
break
}
}
if done {
break
}
}
//todo blocking here suxxxx //todo blocking here suxxxx
// Wait for the packet to have been transmitted. // Wait for the packet to have been transmitted.
for i := range chainIndexes { for i := range chainIndexes {
<-dev.transmitted[chainIndexes[i]]
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil { if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
return fmt.Errorf("free descriptor chain: %w", err) return fmt.Errorf("free descriptor chain: %w", err)
@@ -338,106 +291,104 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
return nil return nil
} }
// ReceivePacket reads the next available packet from the receive queue of this // TODO: Make above methods cancelable by taking a context.Context argument?
// device and returns its [virtio.NetHdr] and packet data separately. // TODO: Implement zero-copy variants to transmit and receive packets?
//
// When no packet is available, this will block until there is one.
//
// When this method returns an error, the receive queue will likely be in a
// broken state which this implementation cannot recover from. The caller should
// close the device and not attempt any additional receives.
func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
var (
chainHeads []uint16
vnethdr virtio.NetHdr // processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
buffers [][]byte func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need:
pkt.Payload = pkt.Payload[:cap(pkt.Payload)]
n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload)
if err != nil {
return 0, err
}
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = pkt.Header.Decode(pkt.Payload[0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return 0, fmt.Errorf("decode vnethdr: %w", err)
}
// Each packet starts with a virtio-net header which we have to subtract //we have the header now: what do we need to do?
// from the total length. if int(pkt.Header.NumBuffers) > len(chains) {
packetLength = -virtio.NetHdrSize return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
) }
lenRead := 0 //shift the buffer out of out:
copy(pkt.Payload, pkt.Payload[virtio.NetHdrSize:])
// We presented FeatureNetMergeRXBuffers to the device, so one packet may be cursor := n - virtio.NetHdrSize
// made of multiple descriptor chains which are to be merged.
for remainingChains := 1; remainingChains > 0; remainingChains-- {
// Get the next descriptor chain.
usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
if !ok {
return 0, virtio.NetHdr{}, ErrDeviceClosed
}
// Track this chain to be freed later. if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
head := uint16(usedElement.DescriptorIndex) pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
chainHeads = append(chainHeads, head) return 1, nil
}
n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:]) i := 1
// we used chain 0 already
for i = 1; i < len(chains); i++ {
n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:])
if err != nil { if err != nil {
// When this fails we may miss to free some descriptor chains. We // When this fails we may miss to free some descriptor chains. We
// could try to mitigate this by deferring the freeing somehow, but // could try to mitigate this by deferring the freeing somehow, but
// it's not worth the hassle. When this method fails, the queue will // it's not worth the hassle. When this method fails, the queue will
// be in a broken state anyway. // be in a broken state anyway.
return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err) return i, fmt.Errorf("get descriptor chain: %w", err)
} }
lenRead += n cursor += n
packetLength += int(usedElement.Length) }
//todo this has to be wrong
pkt.Payload = pkt.Payload[:cursor]
return i, nil
}
// Is this the first descriptor chain we process? func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
if len(buffers) == 0 { //todo optimize?
// The specification requires that the first descriptor chain starts var chains []virtqueue.UsedElement
// with a virtio-net header. It is not clear, whether it is also var err error
// required to be fully contained in the first buffer of that //if len(dev.extraRx) == 0 {
// descriptor chain, but it is reasonable to assume that this is chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), 64) //todo config batch
// always the case. if err != nil {
// The decode method already does the buffer length check. return 0, err
if err = vnethdr.Decode(out[0:]); err != nil { }
// The device misbehaved. There is no way we can gracefully if len(chains) == 0 {
// recover from this, because we don't know how many of the return 0, nil
// following descriptor chains belong to this packet. }
return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err) //} else {
} // chains = dev.extraRx
lenRead = 0 //}
out = out[virtio.NetHdrSize:]
// The virtio-net header tells us how many descriptor chains this numPackets := 0
// packet is long. chainsIdx := 0
remainingChains = int(vnethdr.NumBuffers) for numPackets = 0; chainsIdx < len(chains); numPackets++ {
if numPackets >= len(out) {
//dev.extraRx = chains[chainsIdx:]
//return numPackets, nil
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
} }
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
//buffers = append(buffers, inBuffers...) if err != nil {
return 0, err
}
chainsIdx += numChains
} }
// Copy all the buffers together to produce the complete packet slice. // Now that we have copied all buffers, we can recycle the used descriptor chains
//out = out[:packetLength] if err := dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
//copied := 0 return 0, err
//for _, buffer := range buffers {
// copied += copy(out[copied:], buffer)
//}
//if copied != packetLength {
// panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
//}
// Now that we have copied all buffers, we can free the used descriptor
// chains again.
// TODO: Recycling the descriptor chains would be more efficient than
// freeing them just to offer them again right after.
for _, head := range chainHeads {
if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
}
} }
//if we don't churn chains, maybe we don't need this? //if we don't churn chains, maybe we don't need this?
// It's advised to always keep the receive queue fully populated with // It's advised to always keep the rx queue fully populated with available buffers which the device can write new packets into.
// available buffers which the device can write new packets into.
//if err := dev.refillReceiveQueue(); err != nil { //if err := dev.refillReceiveQueue(); err != nil {
// return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err) // return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
//} //}
return packetLength, vnethdr, nil return numPackets, nil
} }
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?

View File

@@ -82,22 +82,61 @@ func (r *AvailableRing) Address() uintptr {
// offer adds the given descriptor chain heads to the available ring and // offer adds the given descriptor chain heads to the available ring and
// advances the ring index accordingly to make the device process the new // advances the ring index accordingly to make the device process the new
// descriptor chains. // descriptor chains.
func (r *AvailableRing) offer(chainHeads []uint16) { func (r *AvailableRing) offerElements(chains []UsedElement) {
//always called under lock //always called under lock
//r.mu.Lock() //r.mu.Lock()
//defer r.mu.Unlock() //defer r.mu.Unlock()
// Add descriptor chain heads to the ring. // Add descriptor chain heads to the ring.
for offset, head := range chainHeads { for offset, x := range chains {
// The 16-bit ring index may overflow. This is expected and is not an // The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue // issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible // size) is always a power of 2 and smaller than the highest possible
// 16-bit value. // 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = head r.ring[insertIndex] = x.GetHead()
} }
// Increase the ring index by the number of descriptor chains added to the // Increase the ring index by the number of descriptor chains added to the
// ring. // ring.
*r.ringIndex += uint16(len(chainHeads)) *r.ringIndex += uint16(len(chains))
}
// offer writes the given descriptor chain heads into consecutive slots of the
// available ring, beginning at the slot selected by the current ring index,
// and then advances the ring index by the number of heads written.
//
// This method takes no lock itself; it is always called with the lock held.
func (r *AvailableRing) offer(chains []uint16) {
	start := *r.ringIndex
	slots := len(r.ring)

	for i := range chains {
		// The 16-bit sum is allowed to wrap around. That is expected: the ring
		// length equals the queue size, which is always a power of 2 and
		// smaller than the highest possible 16-bit value, so reducing the
		// wrapped sum modulo the ring length still selects the right slot.
		r.ring[int(start+uint16(i))%slots] = chains[i]
	}

	// Publish the new ring index so it accounts for every head just added.
	*r.ringIndex = start + uint16(len(chains))
}
// offerSingle adds one descriptor chain head to the available ring and
// advances the ring index by one. The head x is stored in the slot selected by
// the current ring index modulo the ring length.
func (r *AvailableRing) offerSingle(x uint16) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
// offset is always zero here: only one head is written, so it lands in the
// slot the ring index currently points at.
offset := 0
// Add descriptor chain heads to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x
// Increase the ring index by the number of descriptor chains added to the ring.
*r.ringIndex += 1
} }

View File

@@ -30,9 +30,9 @@ type SplitQueue struct {
// chains and put them in the used ring. // chains and put them in the used ring.
callEventFD eventfd.EventFD callEventFD eventfd.EventFD
// usedChains is a channel that receives [UsedElement]s for descriptor chains // UsedChains is a channel that receives [UsedElement]s for descriptor chains
// that were used by the device. // that were used by the device.
usedChains chan UsedElement UsedChains chan UsedElement
// moreFreeDescriptors is a channel that signals when any descriptors were // moreFreeDescriptors is a channel that signals when any descriptors were
// put back into the free chain of the descriptor table. This is used to // put back into the free chain of the descriptor table. This is used to
@@ -51,6 +51,7 @@ type SplitQueue struct {
itemSize int itemSize int
epoll eventfd.Epoll epoll eventfd.Epoll
more int
} }
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
@@ -131,7 +132,7 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
} }
// Initialize channels. // Initialize channels.
sq.usedChains = make(chan UsedElement, queueSize) sq.UsedChains = make(chan UsedElement, queueSize)
sq.moreFreeDescriptors = make(chan struct{}) sq.moreFreeDescriptors = make(chan struct{})
sq.epoll, err = eventfd.NewEpoll() sq.epoll, err = eventfd.NewEpoll()
@@ -190,20 +191,6 @@ func (sq *SplitQueue) CallEventFD() int {
return sq.callEventFD.FD() return sq.callEventFD.FD()
} }
// UsedDescriptorChains returns the channel that receives [UsedElement]s for all
// descriptor chains that were used by the device.
//
// Users of the [SplitQueue] should read from this channel, handle the used
// descriptor chains and free them using [SplitQueue.FreeDescriptorChain] when
// they're done with them. When this does not happen, the queue will run full
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
//
// When [SplitQueue.Close] is called, this channel will be closed as well.
func (sq *SplitQueue) UsedDescriptorChains() chan UsedElement {
sq.ensureInitialized()
return sq.usedChains
}
// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing]. // startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
// A function is returned that can be used to gracefully cancel it. todo rename // A function is returned that can be used to gracefully cancel it. todo rename
func (sq *SplitQueue) startConsumeUsedRing() func() error { func (sq *SplitQueue) startConsumeUsedRing() func() error {
@@ -225,14 +212,49 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro
var n int var n int
var err error var err error
for ctx.Err() == nil { for ctx.Err() == nil {
// Wait for a signal from the device. // Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil { if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err) return nil, fmt.Errorf("wait: %w", err)
} }
if n > 0 { if n > 0 {
out := sq.usedRing.take() stillNeedToTake, out := sq.usedRing.take(-1)
_ = sq.epoll.Clear() //??? sq.more = stillNeedToTake
if stillNeedToTake == 0 {
_ = sq.epoll.Clear() //???
}
return out, nil
}
}
return nil, ctx.Err()
}
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
//we have leftovers in the fridge
if sq.more > 0 {
stillNeedToTake, out := sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
if stillNeedToTake == 0 {
_ = sq.epoll.Clear() //???
}
return out, nil
}
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
stillNeedToTake, out := sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
if stillNeedToTake == 0 {
_ = sq.epoll.Clear() //???
}
return out, nil return out, nil
} }
} }
@@ -240,12 +262,6 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro
return nil, ctx.Err() return nil, ctx.Err()
} }
// blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up.
// it is its own function so maybe it might show up in pprof
func (sq *SplitQueue) blockForMoreDescriptors() {
<-sq.moreFreeDescriptors
}
// OfferDescriptorChain offers a descriptor chain to the device which contains a // OfferDescriptorChain offers a descriptor chain to the device which contains a
// number of device-readable buffers (out buffers) and device-writable buffers // number of device-readable buffers (out buffers) and device-writable buffers
// (in buffers). // (in buffers).
@@ -271,63 +287,9 @@ func (sq *SplitQueue) blockForMoreDescriptors() {
// used descriptor chains again using [SplitQueue.FreeDescriptorChain] when // used descriptor chains again using [SplitQueue.FreeDescriptorChain] when
// they're done with them. When this does not happen, the queue will run full // they're done with them. When this does not happen, the queue will run full
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall. // and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int, waitFree bool) (uint16, error) {
func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int) (uint16, error) {
sq.ensureInitialized() sq.ensureInitialized()
// TODO change this
// Each descriptor can only hold a whole memory page, so split large out
// buffers into multiple smaller ones.
outBuffers = splitBuffers(outBuffers, sq.pageSize)
// Synchronize the offering of descriptor chains. While the descriptor table
// and available ring are synchronized on their own as well, this does not
// protect us from interleaved calls which could cause reordering.
// By locking here, we can ensure that all descriptor chains are made
// available to the device in the same order as this method was called.
sq.offerMutex.Lock()
defer sq.offerMutex.Unlock()
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for {
head, err = sq.descriptorTable.createDescriptorChain(outBuffers, numInBuffers)
if err == nil {
break
}
// I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
if waitFree {
// Wait for more free descriptors to be put back into the queue.
// If the number of free descriptors is still not sufficient, we'll
// land here again.
sq.blockForMoreDescriptors()
continue
} else {
return 0, err
}
}
return 0, fmt.Errorf("create descriptor chain: %w", err)
}
// Make the descriptor chain available to the device.
sq.availableRing.offer([]uint16{head})
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return head, fmt.Errorf("notify device: %w", err)
}
return head, nil
}
func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (uint16, error) {
sq.ensureInitialized()
// Synchronize the offering of descriptor chains. While the descriptor table // Synchronize the offering of descriptor chains. While the descriptor table
// and available ring are synchronized on their own as well, this does not // and available ring are synchronized on their own as well, this does not
// protect us from interleaved calls which could cause reordering. // protect us from interleaved calls which could cause reordering.
@@ -350,21 +312,14 @@ func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (
// I don't wanna use errors.Is, it's slow // I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors //goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors { if err == ErrNotEnoughFreeDescriptors {
if waitFree { return 0, err
// Wait for more free descriptors to be put back into the queue. } else {
// If the number of free descriptors is still not sufficient, we'll return 0, fmt.Errorf("create descriptor chain: %w", err)
// land here again.
sq.blockForMoreDescriptors()
continue
} else {
return 0, err
}
} }
return 0, fmt.Errorf("create descriptor chain: %w", err)
} }
// Make the descriptor chain available to the device. // Make the descriptor chain available to the device.
sq.availableRing.offer([]uint16{head}) sq.availableRing.offerSingle(head)
// Notify the device to make it process the updated available ring. // Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil { if err := sq.kickEventFD.Kick(); err != nil {
@@ -374,7 +329,7 @@ func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (
return head, nil return head, nil
} }
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) { func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
sq.ensureInitialized() sq.ensureInitialized()
// TODO change this // TODO change this
@@ -408,15 +363,11 @@ func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]by
// I don't wanna use errors.Is, it's slow // I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors //goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors { if err == ErrNotEnoughFreeDescriptors {
if waitFree { // Wait for more free descriptors to be put back into the queue.
// Wait for more free descriptors to be put back into the queue. // If the number of free descriptors is still not sufficient, we'll
// If the number of free descriptors is still not sufficient, we'll // land here again.
// land here again. <-sq.moreFreeDescriptors
sq.blockForMoreDescriptors() continue
continue
} else {
return nil, err
}
} }
return nil, fmt.Errorf("create descriptor chain: %w", err) return nil, fmt.Errorf("create descriptor chain: %w", err)
} }
@@ -473,7 +424,7 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
// There is more free room in the descriptor table now. // There is more free room in the descriptor table now.
// This is a fire-and-forget signal, so do not block when nobody listens. // This is a fire-and-forget signal, so do not block when nobody listens.
select { select { //todo eliminate
case sq.moreFreeDescriptors <- struct{}{}: case sq.moreFreeDescriptors <- struct{}{}:
default: default:
} }
@@ -481,7 +432,7 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
return nil return nil
} }
func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error { func (sq *SplitQueue) RecycleDescriptorChains(chains []UsedElement) error {
sq.ensureInitialized() sq.ensureInitialized()
//todo I don't think we need this here? //todo I don't think we need this here?
@@ -500,7 +451,7 @@ func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error {
//} //}
// Make the descriptor chain available to the device. // Make the descriptor chain available to the device.
sq.availableRing.offer([]uint16{head}) sq.availableRing.offerElements(chains)
// Notify the device to make it process the updated available ring. // Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil { if err := sq.kickEventFD.Kick(); err != nil {
@@ -524,7 +475,7 @@ func (sq *SplitQueue) Close() error {
// The stop function blocked until the goroutine ended, so the channel // The stop function blocked until the goroutine ended, so the channel
// can now safely be closed. // can now safely be closed.
close(sq.usedChains) close(sq.UsedChains)
// Make sure that this code block is executed only once. // Make sure that this code block is executed only once.
sq.stop = nil sq.stop = nil

View File

@@ -15,3 +15,7 @@ type UsedElement struct {
// the buffer described by the descriptor chain. // the buffer described by the descriptor chain.
Length uint32 Length uint32
} }
// GetHead returns the element's DescriptorIndex converted to uint16.
//
// NOTE(review): presumably this is the head index of the used descriptor
// chain, and the conversion drops any bits above the low 16 — confirm that
// DescriptorIndex can never exceed the queue size.
func (u *UsedElement) GetHead() uint16 {
	head := uint16(u.DescriptorIndex)
	return head
}

View File

@@ -87,14 +87,14 @@ func (r *UsedRing) Address() uintptr {
// take returns all new [UsedElement]s that the device put into the ring and // take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method. // that weren't already returned by a previous call to this method.
// had a lock, I removed it // had a lock, I removed it
func (r *UsedRing) take() []UsedElement { func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
//r.mu.Lock() //r.mu.Lock()
//defer r.mu.Unlock() //defer r.mu.Unlock()
ringIndex := *r.ringIndex ringIndex := *r.ringIndex
if ringIndex == r.lastIndex { if ringIndex == r.lastIndex {
// Nothing new. // Nothing new.
return nil return 0, nil
} }
// Calculate the number new used elements that we can read from the ring. // Calculate the number new used elements that we can read from the ring.
@@ -104,6 +104,16 @@ func (r *UsedRing) take() []UsedElement {
count += 0xffff count += 0xffff
} }
stillNeedToTake := 0
if maxToTake > 0 {
stillNeedToTake = count - maxToTake
if stillNeedToTake < 0 {
stillNeedToTake = 0
}
count = min(count, maxToTake)
}
// The number of new elements can never exceed the queue size. // The number of new elements can never exceed the queue size.
if count > len(r.ring) { if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long") panic("used ring contains more new elements than the ring is long")
@@ -115,5 +125,5 @@ func (r *UsedRing) take() []UsedElement {
r.lastIndex++ r.lastIndex++
} }
return elems return stillNeedToTake, elems
} }

16
packet/virtio.go Normal file
View File

@@ -0,0 +1,16 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
)
// VirtIOPacket bundles a payload buffer with a virtio.NetHdr value.
type VirtIOPacket struct {
// Payload is the byte buffer carried by this packet. NewVIO allocates it
// with length Size.
Payload []byte
// Header is the virtio.NetHdr stored alongside Payload.
// NOTE(review): presumably the virtio-net header describing Payload — confirm.
Header virtio.NetHdr
}
// NewVIO returns a pointer to a freshly allocated VirtIOPacket whose Payload
// is a zero-filled slice of length Size (declared elsewhere in this package).
// Header is left at its zero value.
func NewVIO() *VirtIOPacket {
	return &VirtIOPacket{
		Payload: make([]byte, Size),
	}
}