diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go index 0a3b70e..2ab728b 100644 --- a/overlay/vhostnet/device.go +++ b/overlay/vhostnet/device.go @@ -27,10 +27,8 @@ const ( // Device represents a vhost networking device within the kernel-level virtio // implementation and provides methods to interact with it. type Device struct { - initialized bool - controlFD int + controlFD int - fullTable bool ReceiveQueue *virtqueue.SplitQueue TransmitQueue *virtqueue.SplitQueue } @@ -127,8 +125,6 @@ func NewDevice(options ...Option) (*Device, error) { return nil, fmt.Errorf("refill tx queue: %w", err) } - dev.initialized = true - // Make sure to clean up even when the device gets garbage collected without // Close being called first. devPtr := &dev @@ -179,8 +175,6 @@ func (dev *Device) prefillTxQueue() error { // The implementation will try to release as many resources as possible and // collect potential errors before returning them. func (dev *Device) Close() error { - dev.initialized = false - // Closing the control file descriptor will unregister all queues from the // kernel. if dev.controlFD >= 0 { @@ -233,18 +227,7 @@ func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*v } func (dev *Device) GetPacketForTx() (uint16, []byte, error) { - var err error - var idx uint16 - //if !dev.fullTable { - // idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs() - // if err == virtqueue.ErrNotEnoughFreeDescriptors { - // dev.fullTable = true - // idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO()) - // } - //} else { - // idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO()) - //} - idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO()) + idx, err := dev.TransmitQueue.TakeSingleIndex(context.TODO()) if err != nil { return 0, nil, fmt.Errorf("transmit queue: %w", err) } @@ -266,9 +249,7 @@ func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error { } pkt.Reset() if kick { - if err := dev.TransmitQueue.Kick(); err != nil { - return err - } + return dev.TransmitQueue.Kick() } return nil @@ -293,9 +274,8 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error { // ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned. func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) { //read first element to see how many descriptors we need: - pkt.Reset() - idx := uint16(chain.DescriptorIndex) - buf := dev.ReceiveQueue.GetDescriptorItem(idx) + pkt.Chain = uint16(chain.DescriptorIndex) + buf := dev.ReceiveQueue.GetDescriptorItem(pkt.Chain) // The specification requires that the first descriptor chain starts // with a virtio-net header. It is not clear, whether it is also @@ -324,12 +304,12 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement if chain.Length > 16000 { //todo! + pkt.payload = nil return 1, fmt.Errorf("too big packet length: %d", chain.Length) } //shift the buffer out of out: pkt.payload = buf[virtio.NetHdrSize:chain.Length] - pkt.Chain = idx return 1, nil }