mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 00:44:25 +01:00
more better
This commit is contained in:
@@ -288,17 +288,22 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
|
||||
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
|
||||
//read first element to see how many descriptors we need:
|
||||
pkt.Reset()
|
||||
n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload, int(chains[0].Length)) //todo
|
||||
|
||||
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
||||
}
|
||||
if len(pkt.ChainRefs) == 0 {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// The specification requires that the first descriptor chain starts
|
||||
// with a virtio-net header. It is not clear, whether it is also
|
||||
// required to be fully contained in the first buffer of that
|
||||
// descriptor chain, but it is reasonable to assume that this is
|
||||
// always the case.
|
||||
// The decode method already does the buffer length check.
|
||||
if err = pkt.Header.Decode(pkt.Payload[0:]); err != nil {
|
||||
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
|
||||
// The device misbehaved. There is no way we can gracefully
|
||||
// recover from this, because we don't know how many of the
|
||||
// following descriptor chains belong to this packet.
|
||||
@@ -309,33 +314,43 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
|
||||
if int(pkt.Header.NumBuffers) > len(chains) {
|
||||
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
|
||||
}
|
||||
if int(pkt.Header.NumBuffers) != 1 {
|
||||
return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
|
||||
}
|
||||
if chains[0].Length > 4000 {
|
||||
//todo!
|
||||
return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
|
||||
}
|
||||
|
||||
//shift the buffer out of out:
|
||||
pkt.Payload = pkt.Payload[virtio.NetHdrSize:]
|
||||
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
|
||||
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
|
||||
pkt.Recycler = dev.ReceiveQueue.RecycleDescriptorChains
|
||||
return 1, nil
|
||||
|
||||
cursor := n - virtio.NetHdrSize
|
||||
|
||||
if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
|
||||
pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
i := 1
|
||||
// we used chain 0 already
|
||||
for i = 1; i < len(chains); i++ {
|
||||
n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
|
||||
if err != nil {
|
||||
// When this fails we may miss to free some descriptor chains. We
|
||||
// could try to mitigate this by deferring the freeing somehow, but
|
||||
// it's not worth the hassle. When this method fails, the queue will
|
||||
// be in a broken state anyway.
|
||||
return i, fmt.Errorf("get descriptor chain: %w", err)
|
||||
}
|
||||
cursor += n
|
||||
}
|
||||
//todo this has to be wrong
|
||||
pkt.Payload = pkt.Payload[:cursor]
|
||||
return i, nil
|
||||
//cursor := n - virtio.NetHdrSize
|
||||
//
|
||||
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
|
||||
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
|
||||
// return 1, nil
|
||||
//}
|
||||
//
|
||||
//i := 1
|
||||
//// we used chain 0 already
|
||||
//for i = 1; i < len(chains); i++ {
|
||||
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
|
||||
// if err != nil {
|
||||
// // When this fails we may miss to free some descriptor chains. We
|
||||
// // could try to mitigate this by deferring the freeing somehow, but
|
||||
// // it's not worth the hassle. When this method fails, the queue will
|
||||
// // be in a broken state anyway.
|
||||
// return i, fmt.Errorf("get descriptor chain: %w", err)
|
||||
// }
|
||||
// cursor += n
|
||||
//}
|
||||
////todo this has to be wrong
|
||||
//pkt.Payload = pkt.Payload[:cursor]
|
||||
//return i, nil
|
||||
}
|
||||
|
||||
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
||||
@@ -368,9 +383,9 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
||||
}
|
||||
|
||||
// Now that we have copied all buffers, we can recycle the used descriptor chains
|
||||
if err = dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
//if err = dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
|
||||
return numPackets, nil
|
||||
}
|
||||
|
||||
@@ -334,6 +334,48 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
|
||||
return
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
||||
if int(head) > len(dt.descriptors) {
|
||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
// Iterate over the chain. The iteration is limited to the queue size to
|
||||
// avoid ending up in an endless loop when things go very wrong.
|
||||
next := head
|
||||
for range len(dt.descriptors) {
|
||||
if next == dt.freeHeadIndex {
|
||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
desc := &dt.descriptors[next]
|
||||
|
||||
// The descriptor address points to memory not managed by Go, so this
|
||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
||||
|
||||
if desc.flags&descriptorFlagWritable == 0 {
|
||||
return fmt.Errorf("there should not be an outbuffer in %d", head)
|
||||
} else {
|
||||
*inBuffers = append(*inBuffers, bs)
|
||||
}
|
||||
|
||||
// Is this the tail of the chain?
|
||||
if desc.flags&descriptorFlagHasNext == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Detect loops.
|
||||
if desc.next == head {
|
||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
next = desc.next
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
|
||||
if int(head) > len(dt.descriptors) {
|
||||
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||
|
||||
@@ -363,6 +363,10 @@ func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen
|
||||
return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
||||
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
|
||||
}
|
||||
|
||||
// FreeDescriptorChain frees the descriptor chain with the given head index.
|
||||
// The head index must be one that was returned by a previous call to
|
||||
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
||||
@@ -381,7 +385,7 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) RecycleDescriptorChains(chains []UsedElement) error {
|
||||
func (sq *SplitQueue) RecycleDescriptorChains(chains []uint16, kick bool) error {
|
||||
//todo not doing this may break eventually?
|
||||
//not called under lock
|
||||
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
|
||||
@@ -389,11 +393,13 @@ func (sq *SplitQueue) RecycleDescriptorChains(chains []UsedElement) error {
|
||||
//}
|
||||
|
||||
// Make the descriptor chain available to the device.
|
||||
sq.availableRing.offerElements(chains)
|
||||
sq.availableRing.offer(chains)
|
||||
|
||||
// Notify the device to make it process the updated available ring.
|
||||
if err := sq.kickEventFD.Kick(); err != nil {
|
||||
return fmt.Errorf("notify device: %w", err)
|
||||
if kick {
|
||||
if err := sq.kickEventFD.Kick(); err != nil {
|
||||
return fmt.Errorf("notify device: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user