refactoring a bit

This commit is contained in:
JackDoan
2025-12-18 14:07:28 -06:00
parent 41c9a3b2eb
commit 726e282d0a
3 changed files with 38 additions and 38 deletions

View File

@@ -223,10 +223,7 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err) return 0, nil, fmt.Errorf("transmit queue: %w", err)
} }
buf, err := dev.TransmitQueue.GetDescriptorItem(idx) buf := dev.TransmitQueue.GetDescriptorItem(idx)
if err != nil {
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
}
return idx, buf, nil return idx, buf, nil
} }
@@ -273,10 +270,7 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
//read first element to see how many descriptors we need: //read first element to see how many descriptors we need:
pkt.Reset() pkt.Reset()
idx := uint16(chain.DescriptorIndex) idx := uint16(chain.DescriptorIndex)
buf, err := dev.ReceiveQueue.GetDescriptorItem(idx) buf := dev.ReceiveQueue.GetDescriptorItem(idx)
if err != nil {
return 0, fmt.Errorf("get descriptor chain: %w", err)
}
// The specification requires that the first descriptor chain starts // The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also // with a virtio-net header. It is not clear, whether it is also
@@ -284,20 +278,25 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
// descriptor chain, but it is reasonable to assume that this is // descriptor chain, but it is reasonable to assume that this is
// always the case. // always the case.
// The decode method already does the buffer length check. // The decode method already does the buffer length check.
if err = pkt.header.Decode(buf); err != nil {
// The device misbehaved. There is no way we can gracefully //HACK: we only want the last bit of the header, the NumBuffers field. So, let's grab just that:
// recover from this, because we don't know how many of the //numBuffers := binary.BigEndian.Uint16(buf[virtio.NetHdrSize-3:])
// following descriptor chains belong to this packet. //even bigger hack: apparently this is hitting some kind of memory access pitfall? Let's only grab the last byte:
return 0, fmt.Errorf("decode vnethdr: %w", err) //numBuffers := buf[virtio.NetHdrSize-2]
}
//if err = pkt.header.Decode(buf); err != nil {
// // The device misbehaved. There is no way we can gracefully
// // recover from this, because we don't know how many of the
// // following descriptor chains belong to this packet.
// return 0, fmt.Errorf("decode vnethdr: %w", err)
//}
//we have the header now: what do we need to do? //we have the header now: what do we need to do?
if int(pkt.header.NumBuffers) > 1 { //todo we're ignoring the header lol
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1) //if int(numBuffers) != 1 {
} // return 0, fmt.Errorf("too smol-brain to handle more than one buffer per Chain item right now: %d chains, %d bufs", 1, int(numBuffers))
if int(pkt.header.NumBuffers) != 1 { //}
return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers))
}
if chain.Length > 16000 { if chain.Length > 16000 {
//todo! //todo!
return 1, fmt.Errorf("too big packet length: %d", chain.Length) return 1, fmt.Errorf("too big packet length: %d", chain.Length)
@@ -311,8 +310,8 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
type VirtIOPacket struct { type VirtIOPacket struct {
payload []byte payload []byte
header virtio.NetHdr //header virtio.NetHdr
Chains []uint16 Chains []uint16
} }
func NewVIO() *VirtIOPacket { func NewVIO() *VirtIOPacket {
@@ -328,8 +327,8 @@ func (v *VirtIOPacket) Reset() {
} }
func (v *VirtIOPacket) GetPayload() []byte { func (v *VirtIOPacket) GetPayload() []byte {
return v.payload return v.payload //todo this could be dev.ReceiveQueue.GetDescriptorItem(idx)
} }
func (v *VirtIOPacket) SetPayload(x []byte) { func (v *VirtIOPacket) SetPayload(x []byte) {
v.payload = x //todo? v.payload = x
} }

View File

@@ -268,18 +268,13 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
return head, nil return head, nil
} }
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) { func (dt *DescriptorTable) getDescriptorItem(head uint16) []byte {
if int(head) > len(dt.descriptors) {
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
// The descriptor address points to memory not managed by Go, so this // The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625 // conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer //goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length) return unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
return bs, nil
} }
// checkUnusedDescriptorLength asserts that the length of an unused descriptor // checkUnusedDescriptorLength asserts that the length of an unused descriptor

View File

@@ -340,7 +340,7 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Be careful to only access the returned buffer slices when the device is no // Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after // longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called. // [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) { func (sq *SplitQueue) GetDescriptorItem(head uint16) []byte {
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize) sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
return sq.descriptorTable.getDescriptorItem(head) return sq.descriptorTable.getDescriptorItem(head)
} }
@@ -350,13 +350,19 @@ func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
sq.descriptorTable.descriptors[int(head)].length = uint32(sz) sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
} }
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error { func (sq *SplitQueue) OfferDescriptor(chain uint16, kick bool) error {
//todo not doing this may break eventually? // Make the descriptor chain available to the device.
//not called under lock sq.availableRing.offerSingle(chain)
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
// return fmt.Errorf("free: %w", err)
//}
// Notify the device to make it process the updated available ring.
if kick {
return sq.Kick()
}
return nil
}
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
// Make the descriptor chain available to the device. // Make the descriptor chain available to the device.
sq.availableRing.offer(chains) sq.availableRing.offer(chains)