mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-24 01:14:25 +01:00
BAM! a hit from my spice-weasel. Now TX is zero-copy, at the expense of my sanity!
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/vhost"
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
@@ -31,6 +30,7 @@ type Device struct {
|
||||
initialized bool
|
||||
controlFD int
|
||||
|
||||
fullTable bool
|
||||
ReceiveQueue *virtqueue.SplitQueue
|
||||
TransmitQueue *virtqueue.SplitQueue
|
||||
}
|
||||
@@ -123,6 +123,9 @@ func NewDevice(options ...Option) (*Device, error) {
|
||||
if err = dev.refillReceiveQueue(); err != nil {
|
||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
||||
}
|
||||
if err = dev.refillTransmitQueue(); err != nil {
|
||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
||||
}
|
||||
|
||||
dev.initialized = true
|
||||
|
||||
@@ -139,7 +142,7 @@ func NewDevice(options ...Option) (*Device, error) {
|
||||
// packets.
|
||||
func (dev *Device) refillReceiveQueue() error {
|
||||
for {
|
||||
_, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
|
||||
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
|
||||
if err != nil {
|
||||
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
||||
// Queue is full, job is done.
|
||||
@@ -150,6 +153,22 @@ func (dev *Device) refillReceiveQueue() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (dev *Device) refillTransmitQueue() error {
|
||||
//for {
|
||||
// desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
||||
// if err != nil {
|
||||
// if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
||||
// // Queue is full, job is done.
|
||||
// return nil
|
||||
// }
|
||||
// return fmt.Errorf("offer descriptor chain: %w", err)
|
||||
// } else {
|
||||
// dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
|
||||
// }
|
||||
//}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close cleans up the vhost networking device within the kernel and releases
|
||||
// all resources used for it.
|
||||
// The implementation will try to release as many resources as possible and
|
||||
@@ -238,49 +257,67 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
|
||||
// Prepend the packet with its virtio-net header.
|
||||
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
|
||||
if err := vnethdr.Encode(vnethdrBuf); err != nil {
|
||||
return fmt.Errorf("encode vnethdr: %w", err)
|
||||
}
|
||||
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
|
||||
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
|
||||
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
|
||||
var err error
|
||||
var idx uint16
|
||||
if !dev.fullTable {
|
||||
|
||||
chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
|
||||
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
||||
if err == virtqueue.ErrNotEnoughFreeDescriptors {
|
||||
dev.fullTable = true
|
||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
||||
}
|
||||
} else {
|
||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("offer descriptor chain: %w", err)
|
||||
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
||||
}
|
||||
//todo surely there's something better to do here
|
||||
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
|
||||
}
|
||||
return idx, buf, nil
|
||||
}
|
||||
|
||||
for {
|
||||
txedChains, err := dev.TransmitQueue.BlockAndGetHeadsCapped(context.TODO(), len(chainIndexes))
|
||||
if err != nil {
|
||||
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
|
||||
//if pkt.Valid {
|
||||
if len(pkt.SegmentIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for idx := range pkt.SegmentIDs {
|
||||
segmentID := pkt.SegmentIDs[idx]
|
||||
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
|
||||
}
|
||||
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("offer descriptor chains: %w", err)
|
||||
}
|
||||
pkt.Reset()
|
||||
//}
|
||||
//if kick {
|
||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
||||
return err
|
||||
}
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
|
||||
if len(pkts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range pkts {
|
||||
if err := dev.TransmitPacket(pkts[i], false); err != nil {
|
||||
return err
|
||||
} else if len(txedChains) == 0 {
|
||||
continue //todo will this ever exit?
|
||||
}
|
||||
for _, c := range txedChains {
|
||||
idx := slices.Index(chainIndexes, c.GetHead())
|
||||
if idx < 0 {
|
||||
continue
|
||||
} else {
|
||||
_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
|
||||
chainIndexes[idx] = 0 //todo I hope this works
|
||||
}
|
||||
}
|
||||
done := true //optimism!
|
||||
for _, x := range chainIndexes {
|
||||
if x != 0 {
|
||||
done = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Make above methods cancelable by taking a context.Context argument?
|
||||
@@ -327,7 +364,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
|
||||
//shift the buffer out of out:
|
||||
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
|
||||
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
|
||||
pkt.Recycler = dev.ReceiveQueue.RecycleDescriptorChains
|
||||
pkt.Recycler = dev.ReceiveQueue.OfferDescriptorChains
|
||||
return 1, nil
|
||||
|
||||
//cursor := n - virtio.NetHdrSize
|
||||
@@ -385,7 +422,7 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
||||
}
|
||||
|
||||
// Now that we have copied all buffers, we can recycle the used descriptor chains
|
||||
//if err = dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
|
||||
//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user