mirror of
https://github.com/slackhq/nebula.git
synced 2025-12-29 01:58:28 +01:00
refactoring a bit
This commit is contained in:
@@ -223,10 +223,7 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
||||
}
|
||||
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
|
||||
}
|
||||
buf := dev.TransmitQueue.GetDescriptorItem(idx)
|
||||
return idx, buf, nil
|
||||
}
|
||||
|
||||
@@ -273,10 +270,7 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
|
||||
//read first element to see how many descriptors we need:
|
||||
pkt.Reset()
|
||||
idx := uint16(chain.DescriptorIndex)
|
||||
buf, err := dev.ReceiveQueue.GetDescriptorItem(idx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
||||
}
|
||||
buf := dev.ReceiveQueue.GetDescriptorItem(idx)
|
||||
|
||||
// The specification requires that the first descriptor chain starts
|
||||
// with a virtio-net header. It is not clear, whether it is also
|
||||
@@ -284,20 +278,25 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
|
||||
// descriptor chain, but it is reasonable to assume that this is
|
||||
// always the case.
|
||||
// The decode method already does the buffer length check.
|
||||
if err = pkt.header.Decode(buf); err != nil {
|
||||
// The device misbehaved. There is no way we can gracefully
|
||||
// recover from this, because we don't know how many of the
|
||||
// following descriptor chains belong to this packet.
|
||||
return 0, fmt.Errorf("decode vnethdr: %w", err)
|
||||
}
|
||||
|
||||
//HACK: we only want the last bit of the header, the NumBuffers field. So, let's grab just that:
|
||||
//numBuffers := binary.BigEndian.Uint16(buf[virtio.NetHdrSize-3:])
|
||||
//even bigger hack: apparently this is hitting some kind of memory access pitfall? Let's only grab the last byte:
|
||||
//numBuffers := buf[virtio.NetHdrSize-2]
|
||||
|
||||
//if err = pkt.header.Decode(buf); err != nil {
|
||||
// // The device misbehaved. There is no way we can gracefully
|
||||
// // recover from this, because we don't know how many of the
|
||||
// // following descriptor chains belong to this packet.
|
||||
// return 0, fmt.Errorf("decode vnethdr: %w", err)
|
||||
//}
|
||||
|
||||
//we have the header now: what do we need to do?
|
||||
if int(pkt.header.NumBuffers) > 1 {
|
||||
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1)
|
||||
}
|
||||
if int(pkt.header.NumBuffers) != 1 {
|
||||
return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers))
|
||||
}
|
||||
//todo we're ignoring the header lol
|
||||
//if int(numBuffers) != 1 {
|
||||
// return 0, fmt.Errorf("too smol-brain to handle more than one buffer per Chain item right now: %d chains, %d bufs", 1, int(numBuffers))
|
||||
//}
|
||||
|
||||
if chain.Length > 16000 {
|
||||
//todo!
|
||||
return 1, fmt.Errorf("too big packet length: %d", chain.Length)
|
||||
@@ -311,8 +310,8 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
|
||||
|
||||
type VirtIOPacket struct {
|
||||
payload []byte
|
||||
header virtio.NetHdr
|
||||
Chains []uint16
|
||||
//header virtio.NetHdr
|
||||
Chains []uint16
|
||||
}
|
||||
|
||||
func NewVIO() *VirtIOPacket {
|
||||
@@ -328,8 +327,8 @@ func (v *VirtIOPacket) Reset() {
|
||||
}
|
||||
|
||||
func (v *VirtIOPacket) GetPayload() []byte {
|
||||
return v.payload
|
||||
return v.payload //todo this could be dev.ReceiveQueue.GetDescriptorItem(idx)
|
||||
}
|
||||
func (v *VirtIOPacket) SetPayload(x []byte) {
|
||||
v.payload = x //todo?
|
||||
v.payload = x
|
||||
}
|
||||
|
||||
@@ -268,18 +268,13 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
|
||||
return head, nil
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
|
||||
if int(head) > len(dt.descriptors) {
|
||||
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) getDescriptorItem(head uint16) []byte {
|
||||
desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
|
||||
|
||||
// The descriptor address points to memory not managed by Go, so this
|
||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
||||
return bs, nil
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
||||
}
|
||||
|
||||
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
|
||||
|
||||
@@ -340,7 +340,7 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
|
||||
// Be careful to only access the returned buffer slices when the device is no
|
||||
// longer using them. They must not be accessed after
|
||||
// [SplitQueue.FreeDescriptorChain] has been called.
|
||||
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
|
||||
func (sq *SplitQueue) GetDescriptorItem(head uint16) []byte {
|
||||
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
|
||||
return sq.descriptorTable.getDescriptorItem(head)
|
||||
}
|
||||
@@ -350,13 +350,19 @@ func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
|
||||
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
|
||||
//todo not doing this may break eventually?
|
||||
//not called under lock
|
||||
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
|
||||
// return fmt.Errorf("free: %w", err)
|
||||
//}
|
||||
func (sq *SplitQueue) OfferDescriptor(chain uint16, kick bool) error {
|
||||
// Make the descriptor chain available to the device.
|
||||
sq.availableRing.offerSingle(chain)
|
||||
|
||||
// Notify the device to make it process the updated available ring.
|
||||
if kick {
|
||||
return sq.Kick()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
|
||||
// Make the descriptor chain available to the device.
|
||||
sq.availableRing.offer(chains)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user