mirror of https://github.com/slackhq/nebula.git
Commit: broken chkpt
@@ -28,8 +28,8 @@ type Device struct {
 	initialized bool
 	controlFD int
 
-	receiveQueue *virtqueue.SplitQueue
-	transmitQueue *virtqueue.SplitQueue
+	ReceiveQueue *virtqueue.SplitQueue
+	TransmitQueue *virtqueue.SplitQueue
 
 	// transmitted contains channels for each possible descriptor chain head
 	// index. This is used for packet transmit notifications.
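
The comment above describes a notification scheme with one channel per possible descriptor chain head index, so that whichever caller offered a chain can wait on exactly its own slot. A small standalone sketch of that pattern follows; it is an illustration, not code from the repository, and the names transmitNotifier, notify and wait are made up (the real code keeps the channels in Device.transmitted and carries virtqueue.UsedElement values).

package main

import "fmt"

// usedElement mimics the shape of virtqueue.UsedElement for this sketch.
type usedElement struct {
	DescriptorIndex uint32
	Length          uint32
}

// transmitNotifier keeps one channel per descriptor chain head index.
type transmitNotifier struct {
	transmitted []chan usedElement
}

func newTransmitNotifier(queueSize int) *transmitNotifier {
	t := &transmitNotifier{transmitted: make([]chan usedElement, queueSize)}
	for i := range t.transmitted {
		// Capacity 1: the monitor goroutine can hand off a notification
		// without waiting for the transmitter to be ready to receive it.
		t.transmitted[i] = make(chan usedElement, 1)
	}
	return t
}

// notify is what a queue-monitoring goroutine would call per used element.
func (t *transmitNotifier) notify(u usedElement) { t.transmitted[u.DescriptorIndex] <- u }

// wait is what the goroutine that offered chain `head` would block on.
func (t *transmitNotifier) wait(head uint16) usedElement { return <-t.transmitted[head] }

func main() {
	n := newTransmitNotifier(256)
	go n.notify(usedElement{DescriptorIndex: 7, Length: 1514})
	fmt.Println(n.wait(7).Length) // 1514
}
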
@@ -96,17 +96,17 @@ func NewDevice(options ...Option) (*Device, error) {
 	}
 
 	// Initialize and register the queues needed for the networking device.
-	if dev.receiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
+	if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
 		return nil, fmt.Errorf("create receive queue: %w", err)
 	}
-	if dev.transmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
+	if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
 		return nil, fmt.Errorf("create transmit queue: %w", err)
 	}
 
 	// Set up memory mappings for all buffers used by the queues. This has to
 	// happen before a backend for the queues can be registered.
 	memoryLayout := vhost.NewMemoryLayoutForQueues(
-		[]*virtqueue.SplitQueue{dev.receiveQueue, dev.transmitQueue},
+		[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
 	)
 	if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
 		return nil, fmt.Errorf("setup memory layout: %w", err)
@@ -127,7 +127,7 @@ func NewDevice(options ...Option) (*Device, error) {
 	}
 
 	// Initialize channels for transmit notifications.
-	dev.transmitted = make([]chan virtqueue.UsedElement, dev.transmitQueue.Size())
+	dev.transmitted = make([]chan virtqueue.UsedElement, dev.TransmitQueue.Size())
 	for i := range len(dev.transmitted) {
 		// It is important to use a single-element buffered channel here.
 		// When the channel was unbuffered and the monitorTransmitQueue
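
The truncated comment in this hunk argues for a single-element buffered channel. The point, illustrated below with plain channels (nothing here is from the repository): a goroutine that drains the used ring for all chains, as monitorTransmitQueue does, must be able to hand off a notification without waiting for the receiver, otherwise one receiver that is not yet (or no longer) listening would stall notifications for every other chain.

package main

import "fmt"

func main() {
	// With capacity 1, a send succeeds immediately even if the receiver has
	// not started waiting yet; the value is parked in the buffer.
	buffered := make(chan int, 1)
	buffered <- 42          // does not block
	fmt.Println(<-buffered) // 42

	// With an unbuffered channel the same send would block until a receiver
	// is ready, so a single goroutine fanning out notifications for all
	// chains could be held up by one missing receiver.
	unbuffered := make(chan int)
	select {
	case unbuffered <- 42:
		fmt.Println("sent")
	default:
		fmt.Println("would block") // this branch runs
	}
}
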
@@ -159,7 +159,7 @@ func NewDevice(options ...Option) (*Device, error) {
 // in the transmit queue and produces a transmit notification via the
 // corresponding channel.
 func (dev *Device) monitorTransmitQueue() {
-	usedChan := dev.transmitQueue.UsedDescriptorChains()
+	usedChan := dev.TransmitQueue.UsedDescriptorChains()
 	for {
 		used, ok := <-usedChan
 		if !ok {
@@ -180,7 +180,7 @@ func (dev *Device) monitorTransmitQueue() {
 // packets.
 func (dev *Device) refillReceiveQueue() error {
 	for {
-		_, err := dev.receiveQueue.OfferDescriptorChain(nil, 1, false)
+		_, err := dev.ReceiveQueue.OfferDescriptorChain(nil, 1, false)
 		if err != nil {
 			if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
 				// Queue is full, job is done.
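
refillReceiveQueue above keeps offering empty, device-writable descriptor chains until the queue reports that it has no free descriptors left, and treats that condition as success. A generic sketch of this "fill until a sentinel error" pattern, with errQueueFull and offer standing in for virtqueue.ErrNotEnoughFreeDescriptors and ReceiveQueue.OfferDescriptorChain:

package main

import (
	"errors"
	"fmt"
)

// errQueueFull stands in for virtqueue.ErrNotEnoughFreeDescriptors.
var errQueueFull = errors.New("not enough free descriptors")

// refill offers buffers until the queue is full; a full queue is success,
// any other error is propagated.
func refill(offer func() error) error {
	for {
		if err := offer(); err != nil {
			if errors.Is(err, errQueueFull) {
				return nil // queue fully populated, job done
			}
			return fmt.Errorf("offer descriptor chain: %w", err)
		}
	}
}

func main() {
	slots := 4
	err := refill(func() error {
		if slots == 0 {
			return errQueueFull
		}
		slots--
		return nil
	})
	fmt.Println(err, slots) // <nil> 0
}
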
@@ -212,17 +212,17 @@ func (dev *Device) Close() error {
 
 	var errs []error
 
-	if dev.receiveQueue != nil {
-		if err := dev.receiveQueue.Close(); err == nil {
-			dev.receiveQueue = nil
+	if dev.ReceiveQueue != nil {
+		if err := dev.ReceiveQueue.Close(); err == nil {
+			dev.ReceiveQueue = nil
 		} else {
 			errs = append(errs, fmt.Errorf("close receive queue: %w", err))
 		}
 	}
 
-	if dev.transmitQueue != nil {
-		if err := dev.transmitQueue.Close(); err == nil {
-			dev.transmitQueue = nil
+	if dev.TransmitQueue != nil {
+		if err := dev.TransmitQueue.Close(); err == nil {
+			dev.TransmitQueue = nil
 		} else {
 			errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
 		}
@@ -296,7 +296,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
 	outBuffers := [][]byte{vnethdrBuf, packet}
 	//outBuffers := [][]byte{packet}
 
-	chainIndex, err := dev.transmitQueue.OfferDescriptorChain(outBuffers, 0, true)
+	chainIndex, err := dev.TransmitQueue.OfferDescriptorChain(outBuffers, 0, true)
 	if err != nil {
 		return fmt.Errorf("offer descriptor chain: %w", err)
 	}
@@ -304,7 +304,7 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
 	// Wait for the packet to have been transmitted.
 	<-dev.transmitted[chainIndex]
 
-	if err = dev.transmitQueue.FreeDescriptorChain(chainIndex); err != nil {
+	if err = dev.TransmitQueue.FreeDescriptorChain(chainIndex); err != nil {
 		return fmt.Errorf("free descriptor chain: %w", err)
 	}
 
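
Taken together, the two hunks above show the blocking transmit path: offer the header plus packet as one descriptor chain, wait on the per-chain channel, then free the chain. A caller-side sketch follows, assumed to live in the same package so that Device and virtio.NetHdr resolve through the file's existing imports; only the TransmitPacket signature is taken from the diff.

// sendOne is illustrative only. A zero-value virtio.NetHdr asks the device for
// no checksum or segmentation offload, which keeps the example minimal.
func sendOne(dev *Device, packet []byte) error {
	// Blocks until monitorTransmitQueue signals that the chain was used
	// and the chain has been freed again.
	return dev.TransmitPacket(virtio.NetHdr{}, packet)
}
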
@@ -320,7 +320,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 	vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
 	vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
 
-	chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
+	chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
 	if err != nil {
 		return fmt.Errorf("offer descriptor chain: %w", err)
 	}
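
The two hard-coded bytes above are the EtherType field, the last two bytes of the 14-byte Ethernet header that this batch path places after the virtio-net header; 0x86DD means IPv6, which is what the //todo flags as a temporary assumption. One way the todo could be resolved, sketched as an illustration rather than as the repository's plan, is to pick the EtherType from the IP version nibble of the packet:

// etherType returns the two EtherType bytes for a raw IP packet, based on the
// version nibble in its first byte. Hypothetical helper, not from the diff.
func etherType(packet []byte) (hi, lo byte) {
	if len(packet) > 0 && packet[0]>>4 == 4 {
		return 0x08, 0x00 // 0x0800: IPv4
	}
	return 0x86, 0xdd // 0x86DD: IPv6
}
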
@@ -330,7 +330,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 	for i := range chainIndexes {
 		<-dev.transmitted[chainIndexes[i]]
 
-		if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
+		if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
 			return fmt.Errorf("free descriptor chain: %w", err)
 		}
 	}
@@ -346,7 +346,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 // When this method returns an error, the receive queue will likely be in a
 // broken state which this implementation cannot recover from. The caller should
 // close the device and not attempt any additional receives.
-func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
+func (dev *Device) ReceivePacket(out []byte) (int, virtio.NetHdr, error) {
 	var (
 		chainHeads []uint16
 
@@ -358,41 +358,30 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
 		packetLength = -virtio.NetHdrSize
 	)
 
+	lenRead := 0
+
 	// We presented FeatureNetMergeRXBuffers to the device, so one packet may be
 	// made of multiple descriptor chains which are to be merged.
 	for remainingChains := 1; remainingChains > 0; remainingChains-- {
 		// Get the next descriptor chain.
-		usedElement, ok := <-dev.receiveQueue.UsedDescriptorChains()
+		usedElement, ok := <-dev.ReceiveQueue.UsedDescriptorChains()
 		if !ok {
-			return virtio.NetHdr{}, nil, ErrDeviceClosed
+			return 0, virtio.NetHdr{}, ErrDeviceClosed
 		}
 
 		// Track this chain to be freed later.
 		head := uint16(usedElement.DescriptorIndex)
 		chainHeads = append(chainHeads, head)
 
-		outBuffers, inBuffers, err := dev.receiveQueue.GetDescriptorChain(head)
+		n, err := dev.ReceiveQueue.GetDescriptorChainContents(head, out[lenRead:])
 		if err != nil {
 			// When this fails we may miss to free some descriptor chains. We
 			// could try to mitigate this by deferring the freeing somehow, but
 			// it's not worth the hassle. When this method fails, the queue will
 			// be in a broken state anyway.
-			return virtio.NetHdr{}, nil, fmt.Errorf("get descriptor chain: %w", err)
+			return 0, virtio.NetHdr{}, fmt.Errorf("get descriptor chain: %w", err)
 		}
-		if len(outBuffers) > 0 {
-			// How did this happen!?
-			panic("receive queue contains device-readable buffers")
-		}
-		if len(inBuffers) == 0 {
-			// Empty descriptor chains should not be possible.
-			panic("descriptor chain contains no buffers")
-		}
-
-		// The device tells us how many bytes of the descriptor chain it has
-		// actually written to. The specification forces the device to fully
-		// fill up all but the last descriptor chain when multiple descriptor
-		// chains are being merged, but being more compatible here doesn't hurt.
-		inBuffers = truncateBuffers(inBuffers, int(usedElement.Length))
+		lenRead += n
 		packetLength += int(usedElement.Length)
 
 		// Is this the first descriptor chain we process?
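
The length bookkeeping in this hunk works because, with mergeable receive buffers, the virtio-net header sits at the start of the first chain and is counted in that chain's used length. Starting packetLength at -virtio.NetHdrSize therefore makes the sum of the per-chain used lengths come out as the packet size without the header, while lenRead tracks how many bytes were copied into out. A worked example with assumed numbers (a 12-byte header, the size of the mergeable virtio-net header, and three chains):

package main

import "fmt"

func main() {
	const netHdrSize = 12                  // assumed size of the mergeable virtio-net header
	usedLengths := []int{2048, 2048, 1000} // bytes the device wrote per descriptor chain

	packetLength := -netHdrSize
	for _, l := range usedLengths {
		packetLength += l
	}
	fmt.Println(packetLength) // 5084 packet bytes, header excluded
}
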
@@ -403,49 +392,51 @@ func (dev *Device) ReceivePacket() (virtio.NetHdr, []byte, error) {
 			// descriptor chain, but it is reasonable to assume that this is
 			// always the case.
 			// The decode method already does the buffer length check.
-			if err = vnethdr.Decode(inBuffers[0]); err != nil {
+			if err = vnethdr.Decode(out[0:]); err != nil {
 				// The device misbehaved. There is no way we can gracefully
 				// recover from this, because we don't know how many of the
 				// following descriptor chains belong to this packet.
-				return virtio.NetHdr{}, nil, fmt.Errorf("decode vnethdr: %w", err)
+				return 0, virtio.NetHdr{}, fmt.Errorf("decode vnethdr: %w", err)
 			}
-			inBuffers[0] = inBuffers[0][virtio.NetHdrSize:]
+			lenRead = 0
+			out = out[virtio.NetHdrSize:]
 
 			// The virtio-net header tells us how many descriptor chains this
 			// packet is long.
 			remainingChains = int(vnethdr.NumBuffers)
 		}
 
-		buffers = append(buffers, inBuffers...)
+		//buffers = append(buffers, inBuffers...)
 	}
 
 	// Copy all the buffers together to produce the complete packet slice.
-	packet := make([]byte, packetLength)
-	copied := 0
-	for _, buffer := range buffers {
-		copied += copy(packet[copied:], buffer)
-	}
-	if copied != packetLength {
-		panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
-	}
+	//out = out[:packetLength]
+	//copied := 0
+	//for _, buffer := range buffers {
+	//	copied += copy(out[copied:], buffer)
+	//}
+	//if copied != packetLength {
+	//	panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", packetLength, copied))
+	//}
 
 	// Now that we have copied all buffers, we can free the used descriptor
 	// chains again.
 	// TODO: Recycling the descriptor chains would be more efficient than
 	// freeing them just to offer them again right after.
 	for _, head := range chainHeads {
-		if err := dev.receiveQueue.FreeDescriptorChain(head); err != nil {
-			return virtio.NetHdr{}, nil, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
+		if err := dev.ReceiveQueue.FreeAndOfferDescriptorChains(head); err != nil {
+			return 0, virtio.NetHdr{}, fmt.Errorf("free descriptor chain with head index %d: %w", head, err)
 		}
 	}
 
+	//if we don't churn chains, maybe we don't need this?
 	// It's advised to always keep the receive queue fully populated with
 	// available buffers which the device can write new packets into.
-	if err := dev.refillReceiveQueue(); err != nil {
-		return virtio.NetHdr{}, nil, fmt.Errorf("refill receive queue: %w", err)
-	}
+	//if err := dev.refillReceiveQueue(); err != nil {
+	//	return 0, virtio.NetHdr{}, fmt.Errorf("refill receive queue: %w", err)
+	//}
 
-	return vnethdr, packet, nil
+	return packetLength, vnethdr, nil
 }
 
 // TODO: Make above methods cancelable by taking a context.Context argument?