From 1719149594ebf0cde9387bdc7468fb404875c960 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 11 Nov 2025 15:37:54 -0600 Subject: [PATCH] broken chkpt --- interface.go | 30 +-- overlay/tun.go | 5 +- overlay/tun_disabled.go | 5 +- overlay/tun_linux.go | 14 +- overlay/user.go | 5 +- overlay/vhost/ioctl.go | 2 +- overlay/vhostnet/device.go | 269 +++++++++-------------- overlay/virtqueue/available_ring.go | 47 +++- overlay/virtqueue/split_virtqueue.go | 165 +++++--------- overlay/virtqueue/used_element.go | 4 + overlay/virtqueue/used_ring.go | 16 +- packet/virtio.go | 16 ++ {overlay => util}/virtio/doc.go | 0 {overlay => util}/virtio/features.go | 0 {overlay => util}/virtio/net_hdr.go | 0 {overlay => util}/virtio/net_hdr_test.go | 0 16 files changed, 275 insertions(+), 303 deletions(-) create mode 100644 packet/virtio.go rename {overlay => util}/virtio/doc.go (100%) rename {overlay => util}/virtio/features.go (100%) rename {overlay => util}/virtio/net_hdr.go (100%) rename {overlay => util}/virtio/net_hdr_test.go (100%) diff --git a/interface.go b/interface.go index 2b13931..a814b4b 100644 --- a/interface.go +++ b/interface.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/netip" "os" "runtime" @@ -18,7 +17,6 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" - "github.com/slackhq/nebula/overlay/virtio" "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/udp" ) @@ -270,6 +268,7 @@ func (f *Interface) listenOut(q int) { ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() + outPackets := make([]*packet.OutPacket, batch) for i := 0; i < batch; i++ { outPackets[i] = packet.NewOut() @@ -295,16 +294,15 @@ func (f *Interface) listenOut(q int) { if len(outPackets[i].Segments[j]) > 0 { toSend = append(toSend, outPackets[i].Segments[j]) } - } - //toSend = append(toSend, outPackets[i]) - //toSendCount++ } } //toSend = toSend[:toSendCount] - _, err := f.readers[q].WriteMany(toSend) - if err != nil { - f.l.WithError(err).Error("Failed to write messages") + if len(toSend) != 0 { + _, err := f.readers[q].WriteMany(toSend) + if err != nil { + f.l.WithError(err).Error("Failed to write messages") + } } }) } @@ -323,17 +321,15 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - queues := reader.GetQueues() - if len(queues) == 0 { - f.l.Fatal("Failed to get queues") + packets := make([]*packet.VirtIOPacket, batch) + for i := 0; i < batch; i++ { + packets[i] = packet.NewVIO() } - queue := queues[0] for { - n, err := reader.ReadMany(originalPacket) + n, err := reader.ReadMany(packets) //todo!! - pkt := originalPacket[virtio.NetHdrSize : n+virtio.NetHdrSize] if err != nil { if errors.Is(err, os.ErrClosed) && f.closed.Load() { return @@ -344,7 +340,11 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { os.Exit(2) } - f.consumeInsidePacket(pkt, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l)) + //todo vectorize + for _, pkt := range packets[:n] { + f.consumeInsidePacket(pkt.Payload, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l)) + } + } } diff --git a/overlay/tun.go b/overlay/tun.go index b58d6a8..fdf8a55 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -2,19 +2,22 @@ package overlay import ( "fmt" + "io" "net" "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/util" ) const DefaultMTU = 1300 type TunDev interface { - ReadMany([][]byte) (int, error) + io.WriteCloser + ReadMany([]*packet.VirtIOPacket) (int, error) WriteMany([][]byte) (int, error) GetQueues() []*virtqueue.SplitQueue } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index f2e9c6b..1adb062 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -10,6 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) @@ -122,8 +123,8 @@ func (t *disabledTun) WriteMany(b [][]byte) (int, error) { return out, nil } -func (t *disabledTun) ReadMany(b [][]byte) (int, error) { - return t.Read(b[0]) +func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket) (int, error) { + return t.Read(b[0].Payload) } func (t *disabledTun) NewMultiQueueReader() (TunDev, error) { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index f663ee0..9c65162 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -18,10 +18,11 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay/vhostnet" - "github.com/slackhq/nebula/overlay/virtio" "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/util/virtio" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) @@ -713,16 +714,11 @@ func (t *tun) Close() error { return nil } -func (t *tun) ReadMany(p [][]byte) (int, error) { - //todo call consumeUsedRing here instead of its own thread - - n, hdr, err := t.vdev.ReceivePacket(p) //we are TXing +func (t *tun) ReadMany(p []*packet.VirtIOPacket) (int, error) { + n, err := t.vdev.ReceivePackets(p) //we are TXing if err != nil { return 0, err } - if hdr.NumBuffers > 1 { - t.l.WithField("num_buffers", hdr.NumBuffers).Info("wow, lots to TX from tun") - } return n, nil } @@ -739,7 +735,7 @@ func (t *tun) Write(b []byte) (int, error) { NumBuffers: 0, } - err := t.vdev.TransmitPacket(hdr, b) + err := t.vdev.TransmitPackets(hdr, [][]byte{b}) if err != nil { t.l.WithError(err).Error("Transmitting packet") return 0, err diff --git a/overlay/user.go b/overlay/user.go index 34b359f..62cf786 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -7,6 +7,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) @@ -67,8 +68,8 @@ func (d *UserDevice) Close() error { return nil } -func (d *UserDevice) ReadMany(b [][]byte) (int, error) { - return d.Read(b[0]) +func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket) (int, error) { + return d.Read(b[0].Payload) } func (d *UserDevice) WriteMany(b [][]byte) (int, error) { diff --git a/overlay/vhost/ioctl.go b/overlay/vhost/ioctl.go index e24604a..7e855db 100644 --- a/overlay/vhost/ioctl.go +++ b/overlay/vhost/ioctl.go @@ -4,8 +4,8 @@ import ( "fmt" "unsafe" - "github.com/slackhq/nebula/overlay/virtio" "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/util/virtio" "golang.org/x/sys/unix" ) diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go index f4f132c..e4e65cd 100644 --- a/overlay/vhostnet/device.go +++ b/overlay/vhostnet/device.go @@ -1,14 +1,16 @@ package vhostnet import ( + "context" "errors" "fmt" "os" "runtime" "github.com/slackhq/nebula/overlay/vhost" - "github.com/slackhq/nebula/overlay/virtio" "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/packet" + "github.com/slackhq/nebula/util/virtio" "golang.org/x/sys/unix" ) @@ -31,12 +33,7 @@ type Device struct { ReceiveQueue *virtqueue.SplitQueue TransmitQueue *virtqueue.SplitQueue - // transmitted contains channels for each possible descriptor chain head - // index. This is used for packet transmit notifications. - // When a packet was transmitted and the descriptor chain was used by the - // device, the corresponding channel receives the [virtqueue.UsedElement] - // instance provided by the device. - transmitted []chan virtqueue.UsedElement + extraRx []virtqueue.UsedElement } // NewDevice initializes a new vhost networking device within the @@ -126,25 +123,6 @@ func NewDevice(options ...Option) (*Device, error) { return nil, fmt.Errorf("refill receive queue: %w", err) } - // Initialize channels for transmit notifications. - dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size()) - for i := range len(dev.transmitted) { - // It is important to use a single-element buffered channel here. - // When the channel was unbuffered and the monitorTransmitQueue - // goroutine would write into it, the writing would block which could - // lead to deadlocks in case transmit notifications do not arrive in - // order. - // When the goroutine would use fire-and-forget to write into that - // channel, there may be a chance that the TransmitPacket does not - // receive the transmit notification due to this being a race condition. - // Buffering a single transmit notification resolves this without race - // conditions or possible deadlocks. - dev.transmitted[i] = make(chan virtqueue.UsedElement, 1) - } - - // Monitor transmit queue in background. - go dev.monitorTransmitQueue() - dev.initialized = true // Make sure to clean up even when the device gets garbage collected without @@ -155,32 +133,12 @@ func NewDevice(options ...Option) (*Device, error) { return devPtr, nil } -// monitorTransmitQueue waits for the device to advertise used descriptor chains -// in the transmit queue and produces a transmit notification via the -// corresponding channel. -func (dev *Device) monitorTransmitQueue() { - usedChan := dev.TransmitQueue.UsedDescriptorChains() - for { - used, ok := <-usedChan - if !ok { - // The queue was closed. - return - } - if int(used.DescriptorIndex) > len(dev.transmitted) { - panic(fmt.Sprintf("device provided a used descriptor index (%d) that is out of range", - used.DescriptorIndex)) - } - - dev.transmitted[used.DescriptorIndex] <- used - } -} - // refillReceiveQueue offers as many new device-writable buffers to the device // as the queue can fit. The device will then use these to write received // packets. func (dev *Device) refillReceiveQueue() error { for { - _, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false) + _, err := dev.ReceiveQueue.OfferInDescriptorChains(1) if err != nil { if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) { // Queue is full, job is done. @@ -279,38 +237,6 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) { return } -// TransmitPacket writes the given packet into the transmit queue of this -// device. The packet will be prepended with the [virtio.NetHdr]. -// -// When the queue is full, this will block until the queue has enough room to -// transmit the packet. This method will not return before the packet was -// transmitted and the device notifies that it has used the packet buffer. -func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error { - // Prepend the packet with its virtio-net header. - vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY - if err := vnethdr.Encode(vnethdrBuf); err != nil { - return fmt.Errorf("encode vnethdr: %w", err) - } - vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86 - vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype - outBuffers := [][]byte{vnethdrBuf, packet} - //outBuffers := [][]byte{packet} - - chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true) - if err != nil { - return fmt.Errorf("offer descriptor chain: %w", err) - } - - // Wait for the packet to have been transmitted. - <-dev.transmitted[chainIndex] - - if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil { - return fmt.Errorf("free descriptor chain: %w", err) - } - - return nil -} - func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error { // Prepend the packet with its virtio-net header. vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY @@ -320,15 +246,42 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86 vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype - chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true) + chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets) if err != nil { return fmt.Errorf("offer descriptor chain: %w", err) } + //todo surely there's something better to do here + doneYet := map[uint16]bool{} + for _, chain := range chainIndexes { + doneYet[chain] = false + } + + for { + txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO()) + if err != nil { + return err + } else if len(txedChains) == 0 { + continue //todo will this ever exit? + } + for c := range txedChains { + doneYet[txedChains[c].GetHead()] = true + } + done := true //optimism! + for _, x := range doneYet { + if !x { + done = false + break + } + } + + if done { + break + } + } //todo blocking here suxxxx // Wait for the packet to have been transmitted. for i := range chainIndexes { - <-dev.transmitted[chainIndexes[i]] if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil { return fmt.Errorf("free descriptor chain: %w", err) @@ -338,106 +291,104 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro return nil } -// ReceivePacket reads the next available packet from the receive queue of this -// device and returns its [virtio.NetHdr] and packet data separately. -// -// When no packet is available, this will block until there is one. -// -// When this method returns an error, the receive queue will likely be in a -// broken state which this implementation cannot recover from. The caller should -// close the device and not attempt any additional receives. -func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) { - var ( - chainHeads []uint16 +// TODO: Make above methods cancelable by taking a context.Context argument? +// TODO: Implement zero-copy variants to transmit and receive packets? - vnethdr virtio.NetHdr - buffers [][]byte +// processChains processes as many chains as needed to create one packet. The number of processed chains is returned. +func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) { + //read first element to see how many descriptors we need: + pkt.Payload = pkt.Payload[:cap(pkt.Payload)] + n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload) + if err != nil { + return 0, err + } + // The specification requires that the first descriptor chain starts + // with a virtio-net header. It is not clear, whether it is also + // required to be fully contained in the first buffer of that + // descriptor chain, but it is reasonable to assume that this is + // always the case. + // The decode method already does the buffer length check. + if err = pkt.Header.Decode(pkt.Payload[0:]); err != nil { + // The device misbehaved. There is no way we can gracefully + // recover from this, because we don't know how many of the + // following descriptor chains belong to this packet. + return 0, fmt.Errorf("decode vnethdr: %w", err) + } - // Each packet starts with a virtio-net header which we have to subtract - // from the total length. - packetLength = -virtio.NetHdrSize - ) + //we have the header now: what do we need to do? + if int(pkt.Header.NumBuffers) > len(chains) { + return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains)) + } - lenRead := 0 + //shift the buffer out of out: + copy(pkt.Payload, pkt.Payload[virtio.NetHdrSize:]) - // We presented FeatureNetMergeRXBuffers to the device, so one packet may be - // made of multiple descriptor chains which are to be merged. - for remainingChains := 1; remainingChains > 0; remainingChains-- { - // Get the next descriptor chain. - usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains() - if !ok { - return 0, virtio.NetHdr{}, ErrDeviceClosed - } + cursor := n - virtio.NetHdrSize - // Track this chain to be freed later. - head := uint16(usedElement.DescriptorIndex) - chainHeads = append(chainHeads, head) + if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 { + pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize] + return 1, nil + } - n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:]) + i := 1 + // we used chain 0 already + for i = 1; i < len(chains); i++ { + n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:]) if err != nil { // When this fails we may miss to free some descriptor chains. We // could try to mitigate this by deferring the freeing somehow, but // it's not worth the hassle. When this method fails, the queue will // be in a broken state anyway. - return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err) + return i, fmt.Errorf("get descriptor chain: %w", err) } - lenRead += n - packetLength += int(usedElement.Length) + cursor += n + } + //todo this has to be wrong + pkt.Payload = pkt.Payload[:cursor] + return i, nil +} - // Is this the first descriptor chain we process? - if len(buffers) == 0 { - // The specification requires that the first descriptor chain starts - // with a virtio-net header. It is not clear, whether it is also - // required to be fully contained in the first buffer of that - // descriptor chain, but it is reasonable to assume that this is - // always the case. - // The decode method already does the buffer length check. - if err = vnethdr.Decode(out[0:]); err != nil { - // The device misbehaved. There is no way we can gracefully - // recover from this, because we don't know how many of the - // following descriptor chains belong to this packet. - return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err) - } - lenRead = 0 - out = out[virtio.NetHdrSize:] +func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) { + //todo optimize? + var chains []virtqueue.UsedElement + var err error + //if len(dev.extraRx) == 0 { + chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), 64) //todo config batch + if err != nil { + return 0, err + } + if len(chains) == 0 { + return 0, nil + } + //} else { + // chains = dev.extraRx + //} - // The virtio-net header tells us how many descriptor chains this - // packet is long. - remainingChains = int(vnethdr.NumBuffers) + numPackets := 0 + chainsIdx := 0 + for numPackets = 0; chainsIdx < len(chains); numPackets++ { + if numPackets >= len(out) { + //dev.extraRx = chains[chainsIdx:] + //return numPackets, nil + return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets) } - - //buffers = append(buffers, inBuffers...) + numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:]) + if err != nil { + return 0, err + } + chainsIdx += numChains } - // Copy all the buffers together to produce the complete packet slice. - //out = out[:packetLength] - //copied := 0 - //for _, buffer := range buffers { - // copied += copy(out[copied:], buffer) - //} - //if copied != packetLength { - // panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied)) - //} - - // Now that we have copied all buffers, we can free the used descriptor - // chains again. - // TODO: Recycling the descriptor chains would be more efficient than - // freeing them just to offer them again right after. - for _, head := range chainHeads { - if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil { - return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err) - } + // Now that we have copied all buffers, we can recycle the used descriptor chains + if err := dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil { + return 0, err } //if we don't churn chains, maybe we don't need this? - // It's advised to always keep the receive queue fully populated with - // available buffers which the device can write new packets into. + // It's advised to always keep the rx queue fully populated with available buffers which the device can write new packets into. //if err := dev.refillReceiveQueue(); err != nil { // return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err) //} - return packetLength, vnethdr, nil + return numPackets, nil } - -// TODO: Make above methods cancelable by taking a context.Context argument? -// TODO: Implement zero-copy variants to transmit and receive packets? diff --git a/overlay/virtqueue/available_ring.go b/overlay/virtqueue/available_ring.go index b57c167..a73afa2 100644 --- a/overlay/virtqueue/available_ring.go +++ b/overlay/virtqueue/available_ring.go @@ -82,22 +82,61 @@ func (r *AvailableRing) Address() uintptr { // offer adds the given descriptor chain heads to the available ring and // advances the ring index accordingly to make the device process the new // descriptor chains. -func (r *AvailableRing) offer(chainHeads []uint16) { +func (r *AvailableRing) offerElements(chains []UsedElement) { //always called under lock //r.mu.Lock() //defer r.mu.Unlock() // Add descriptor chain heads to the ring. - for offset, head := range chainHeads { + for offset, x := range chains { // The 16-bit ring index may overflow. This is expected and is not an // issue because the size of the ring array (which equals the queue // size) is always a power of 2 and smaller than the highest possible // 16-bit value. insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) - r.ring[insertIndex] = head + r.ring[insertIndex] = x.GetHead() } // Increase the ring index by the number of descriptor chains added to the // ring. - *r.ringIndex += uint16(len(chainHeads)) + *r.ringIndex += uint16(len(chains)) +} + +func (r *AvailableRing) offer(chains []uint16) { + //always called under lock + //r.mu.Lock() + //defer r.mu.Unlock() + + // Add descriptor chain heads to the ring. + for offset, x := range chains { + // The 16-bit ring index may overflow. This is expected and is not an + // issue because the size of the ring array (which equals the queue + // size) is always a power of 2 and smaller than the highest possible + // 16-bit value. + insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) + r.ring[insertIndex] = x + } + + // Increase the ring index by the number of descriptor chains added to the + // ring. + *r.ringIndex += uint16(len(chains)) +} + +func (r *AvailableRing) offerSingle(x uint16) { + //always called under lock + //r.mu.Lock() + //defer r.mu.Unlock() + + offset := 0 + // Add descriptor chain heads to the ring. + + // The 16-bit ring index may overflow. This is expected and is not an + // issue because the size of the ring array (which equals the queue + // size) is always a power of 2 and smaller than the highest possible + // 16-bit value. + insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) + r.ring[insertIndex] = x + + // Increase the ring index by the number of descriptor chains added to the ring. + *r.ringIndex += 1 } diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go index 8f17a50..5d2620d 100644 --- a/overlay/virtqueue/split_virtqueue.go +++ b/overlay/virtqueue/split_virtqueue.go @@ -30,9 +30,9 @@ type SplitQueue struct { // chains and put them in the used ring. callEventFD eventfd.EventFD - // usedChains is a chanel that receives [UsedElement]s for descriptor chains + // UsedChains is a chanel that receives [UsedElement]s for descriptor chains // that were used by the device. - usedChains chan UsedElement + UsedChains chan UsedElement // moreFreeDescriptors is a channel that signals when any descriptors were // put back into the free chain of the descriptor table. This is used to @@ -51,6 +51,7 @@ type SplitQueue struct { itemSize int epoll eventfd.Epoll + more int } // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size @@ -131,7 +132,7 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) { } // Initialize channels. - sq.usedChains = make(chan UsedElement, queueSize) + sq.UsedChains = make(chan UsedElement, queueSize) sq.moreFreeDescriptors = make(chan struct{}) sq.epoll, err = eventfd.NewEpoll() @@ -190,20 +191,6 @@ func (sq *SplitQueue) CallEventFD() int { return sq.callEventFD.FD() } -// UsedDescriptorChains returns the channel that receives [UsedElement]s for all -// descriptor chains that were used by the device. -// -// Users of the [SplitQueue] should read from this channel, handle the used -// descriptor chains and free them using [SplitQueue.FreeDescriptorChain] when -// they're done with them. When this does not happen, the queue will run full -// and any further calls to [SplitQueue.OfferDescriptorChain] will stall. -// -// When [SplitQueue.Close] is called, this channel will be closed as well. -func (sq *SplitQueue) UsedDescriptorChains() chan UsedElement { - sq.ensureInitialized() - return sq.usedChains -} - // startConsumeUsedRing starts a goroutine that runs [consumeUsedRing]. // A function is returned that can be used to gracefully cancel it. todo rename func (sq *SplitQueue) startConsumeUsedRing() func() error { @@ -225,14 +212,49 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro var n int var err error for ctx.Err() == nil { + // Wait for a signal from the device. if n, err = sq.epoll.Block(); err != nil { return nil, fmt.Errorf("wait: %w", err) } - if n > 0 { - out := sq.usedRing.take() - _ = sq.epoll.Clear() //??? + stillNeedToTake, out := sq.usedRing.take(-1) + sq.more = stillNeedToTake + if stillNeedToTake == 0 { + _ = sq.epoll.Clear() //??? + } + return out, nil + } + } + return nil, ctx.Err() +} + +func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) { + var n int + var err error + for ctx.Err() == nil { + + //we have leftovers in the fridge + if sq.more > 0 { + stillNeedToTake, out := sq.usedRing.take(maxToTake) + sq.more = stillNeedToTake + if stillNeedToTake == 0 { + _ = sq.epoll.Clear() //??? + } + + return out, nil + } + + // Wait for a signal from the device. + if n, err = sq.epoll.Block(); err != nil { + return nil, fmt.Errorf("wait: %w", err) + } + if n > 0 { + stillNeedToTake, out := sq.usedRing.take(maxToTake) + sq.more = stillNeedToTake + if stillNeedToTake == 0 { + _ = sq.epoll.Clear() //??? + } return out, nil } } @@ -240,12 +262,6 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro return nil, ctx.Err() } -// blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up. -// it is its own function so maybe it might show up in pprof -func (sq *SplitQueue) blockForMoreDescriptors() { - <-sq.moreFreeDescriptors -} - // OfferDescriptorChain offers a descriptor chain to the device which contains a // number of device-readable buffers (out buffers) and device-writable buffers // (in buffers). @@ -271,63 +287,9 @@ func (sq *SplitQueue) blockForMoreDescriptors() { // used descriptor chains again using [SplitQueue.FreeDescriptorChain] when // they're done with them. When this does not happen, the queue will run full // and any further calls to [SplitQueue.OfferDescriptorChain] will stall. -func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int, waitFree bool) (uint16, error) { + +func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int) (uint16, error) { sq.ensureInitialized() - - // TODO change this - // Each descriptor can only hold a whole memory page, so split large out - // buffers into multiple smaller ones. - outBuffers = splitBuffers(outBuffers, sq.pageSize) - - // Synchronize the offering of descriptor chains. While the descriptor table - // and available ring are synchronized on their own as well, this does not - // protect us from interleaved calls which could cause reordering. - // By locking here, we can ensure that all descriptor chains are made - // available to the device in the same order as this method was called. - sq.offerMutex.Lock() - defer sq.offerMutex.Unlock() - - // Create a descriptor chain for the given buffers. - var ( - head uint16 - err error - ) - for { - head, err = sq.descriptorTable.createDescriptorChain(outBuffers, numInBuffers) - if err == nil { - break - } - - // I don't wanna use errors.Is, it's slow - //goland:noinspection GoDirectComparisonOfErrors - if err == ErrNotEnoughFreeDescriptors { - if waitFree { - // Wait for more free descriptors to be put back into the queue. - // If the number of free descriptors is still not sufficient, we'll - // land here again. - sq.blockForMoreDescriptors() - continue - } else { - return 0, err - } - } - return 0, fmt.Errorf("create descriptor chain: %w", err) - } - - // Make the descriptor chain available to the device. - sq.availableRing.offer([]uint16{head}) - - // Notify the device to make it process the updated available ring. - if err := sq.kickEventFD.Kick(); err != nil { - return head, fmt.Errorf("notify device: %w", err) - } - - return head, nil -} - -func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) (uint16, error) { - sq.ensureInitialized() - // Synchronize the offering of descriptor chains. While the descriptor table // and available ring are synchronized on their own as well, this does not // protect us from interleaved calls which could cause reordering. @@ -350,21 +312,14 @@ func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) ( // I don't wanna use errors.Is, it's slow //goland:noinspection GoDirectComparisonOfErrors if err == ErrNotEnoughFreeDescriptors { - if waitFree { - // Wait for more free descriptors to be put back into the queue. - // If the number of free descriptors is still not sufficient, we'll - // land here again. - sq.blockForMoreDescriptors() - continue - } else { - return 0, err - } + return 0, err + } else { + return 0, fmt.Errorf("create descriptor chain: %w", err) } - return 0, fmt.Errorf("create descriptor chain: %w", err) } // Make the descriptor chain available to the device. - sq.availableRing.offer([]uint16{head}) + sq.availableRing.offerSingle(head) // Notify the device to make it process the updated available ring. if err := sq.kickEventFD.Kick(); err != nil { @@ -374,7 +329,7 @@ func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int, waitFree bool) ( return head, nil } -func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) { +func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) { sq.ensureInitialized() // TODO change this @@ -408,15 +363,11 @@ func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]by // I don't wanna use errors.Is, it's slow //goland:noinspection GoDirectComparisonOfErrors if err == ErrNotEnoughFreeDescriptors { - if waitFree { - // Wait for more free descriptors to be put back into the queue. - // If the number of free descriptors is still not sufficient, we'll - // land here again. - sq.blockForMoreDescriptors() - continue - } else { - return nil, err - } + // Wait for more free descriptors to be put back into the queue. + // If the number of free descriptors is still not sufficient, we'll + // land here again. + <-sq.moreFreeDescriptors + continue } return nil, fmt.Errorf("create descriptor chain: %w", err) } @@ -473,7 +424,7 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error { // There is more free room in the descriptor table now. // This is a fire-and-forget signal, so do not block when nobody listens. - select { + select { //todo eliminate case sq.moreFreeDescriptors <- struct{}{}: default: } @@ -481,7 +432,7 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error { return nil } -func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error { +func (sq *SplitQueue) RecycleDescriptorChains(chains []UsedElement) error { sq.ensureInitialized() //todo I don't think we need this here? @@ -500,7 +451,7 @@ func (sq *SplitQueue) FreeAndOfferDescriptorChains(head uint16) error { //} // Make the descriptor chain available to the device. - sq.availableRing.offer([]uint16{head}) + sq.availableRing.offerElements(chains) // Notify the device to make it process the updated available ring. if err := sq.kickEventFD.Kick(); err != nil { @@ -524,7 +475,7 @@ func (sq *SplitQueue) Close() error { // The stop function blocked until the goroutine ended, so the channel // can now safely be closed. - close(sq.usedChains) + close(sq.UsedChains) // Make sure that this code block is executed only once. sq.stop = nil diff --git a/overlay/virtqueue/used_element.go b/overlay/virtqueue/used_element.go index 7348d87..a4d5d26 100644 --- a/overlay/virtqueue/used_element.go +++ b/overlay/virtqueue/used_element.go @@ -15,3 +15,7 @@ type UsedElement struct { // the buffer described by the descriptor chain. Length uint32 } + +func (u *UsedElement) GetHead() uint16 { + return uint16(u.DescriptorIndex) +} diff --git a/overlay/virtqueue/used_ring.go b/overlay/virtqueue/used_ring.go index 1d7cd2e..c08b48b 100644 --- a/overlay/virtqueue/used_ring.go +++ b/overlay/virtqueue/used_ring.go @@ -87,14 +87,14 @@ func (r *UsedRing) Address() uintptr { // take returns all new [UsedElement]s that the device put into the ring and // that weren't already returned by a previous call to this method. // had a lock, I removed it -func (r *UsedRing) take() []UsedElement { +func (r *UsedRing) take(maxToTake int) (int, []UsedElement) { //r.mu.Lock() //defer r.mu.Unlock() ringIndex := *r.ringIndex if ringIndex == r.lastIndex { // Nothing new. - return nil + return 0, nil } // Calculate the number new used elements that we can read from the ring. @@ -104,6 +104,16 @@ func (r *UsedRing) take() []UsedElement { count += 0xffff } + stillNeedToTake := 0 + + if maxToTake > 0 { + stillNeedToTake = count - maxToTake + if stillNeedToTake < 0 { + stillNeedToTake = 0 + } + count = min(count, maxToTake) + } + // The number of new elements can never exceed the queue size. if count > len(r.ring) { panic("used ring contains more new elements than the ring is long") @@ -115,5 +125,5 @@ func (r *UsedRing) take() []UsedElement { r.lastIndex++ } - return elems + return stillNeedToTake, elems } diff --git a/packet/virtio.go b/packet/virtio.go new file mode 100644 index 0000000..18c5bce --- /dev/null +++ b/packet/virtio.go @@ -0,0 +1,16 @@ +package packet + +import ( + "github.com/slackhq/nebula/util/virtio" +) + +type VirtIOPacket struct { + Payload []byte + Header virtio.NetHdr +} + +func NewVIO() *VirtIOPacket { + out := new(VirtIOPacket) + out.Payload = make([]byte, Size) + return out +} diff --git a/overlay/virtio/doc.go b/util/virtio/doc.go similarity index 100% rename from overlay/virtio/doc.go rename to util/virtio/doc.go diff --git a/overlay/virtio/features.go b/util/virtio/features.go similarity index 100% rename from overlay/virtio/features.go rename to util/virtio/features.go diff --git a/overlay/virtio/net_hdr.go b/util/virtio/net_hdr.go similarity index 100% rename from overlay/virtio/net_hdr.go rename to util/virtio/net_hdr.go diff --git a/overlay/virtio/net_hdr_test.go b/util/virtio/net_hdr_test.go similarity index 100% rename from overlay/virtio/net_hdr_test.go rename to util/virtio/net_hdr_test.go