This commit is contained in:
JackDoan
2025-12-18 15:33:09 -06:00
parent 111efc0779
commit 3338a2a2a1

View File

@@ -27,10 +27,8 @@ const (
// Device represents a vhost networking device within the kernel-level virtio // Device represents a vhost networking device within the kernel-level virtio
// implementation and provides methods to interact with it. // implementation and provides methods to interact with it.
type Device struct { type Device struct {
initialized bool controlFD int
controlFD int
fullTable bool
ReceiveQueue *virtqueue.SplitQueue ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue TransmitQueue *virtqueue.SplitQueue
} }
@@ -127,8 +125,6 @@ func NewDevice(options ...Option) (*Device, error) {
return nil, fmt.Errorf("refill tx queue: %w", err) return nil, fmt.Errorf("refill tx queue: %w", err)
} }
dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without // Make sure to clean up even when the device gets garbage collected without
// Close being called first. // Close being called first.
devPtr := &dev devPtr := &dev
@@ -179,8 +175,6 @@ func (dev *Device) prefillTxQueue() error {
// The implementation will try to release as many resources as possible and // The implementation will try to release as many resources as possible and
// collect potential errors before returning them. // collect potential errors before returning them.
func (dev *Device) Close() error { func (dev *Device) Close() error {
dev.initialized = false
// Closing the control file descriptor will unregister all queues from the // Closing the control file descriptor will unregister all queues from the
// kernel. // kernel.
if dev.controlFD >= 0 { if dev.controlFD >= 0 {
@@ -233,18 +227,7 @@ func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*v
} }
func (dev *Device) GetPacketForTx() (uint16, []byte, error) { func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
var err error idx, err := dev.TransmitQueue.TakeSingleIndex(context.TODO())
var idx uint16
//if !dev.fullTable {
// idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
// if err == virtqueue.ErrNotEnoughFreeDescriptors {
// dev.fullTable = true
// idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
// }
//} else {
// idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
//}
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err) return 0, nil, fmt.Errorf("transmit queue: %w", err)
} }
@@ -266,9 +249,7 @@ func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
} }
pkt.Reset() pkt.Reset()
if kick { if kick {
if err := dev.TransmitQueue.Kick(); err != nil { return dev.TransmitQueue.Kick()
return err
}
} }
return nil return nil
@@ -293,9 +274,8 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
// ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned. // ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned.
func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) { func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need: //read first element to see how many descriptors we need:
pkt.Reset() pkt.Chain = uint16(chain.DescriptorIndex)
idx := uint16(chain.DescriptorIndex) buf := dev.ReceiveQueue.GetDescriptorItem(pkt.Chain)
buf := dev.ReceiveQueue.GetDescriptorItem(idx)
// The specification requires that the first descriptor chain starts // The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also // with a virtio-net header. It is not clear, whether it is also
@@ -324,12 +304,12 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
if chain.Length > 16000 { if chain.Length > 16000 {
//todo! //todo!
pkt.payload = nil
return 1, fmt.Errorf("too big packet length: %d", chain.Length) return 1, fmt.Errorf("too big packet length: %d", chain.Length)
} }
//shift the buffer out of out: //shift the buffer out of out:
pkt.payload = buf[virtio.NetHdrSize:chain.Length] pkt.payload = buf[virtio.NetHdrSize:chain.Length]
pkt.Chain = idx
return 1, nil return 1, nil
} }