diff --git a/outside.go b/outside.go
index 5ef80fa..0c44cec 100644
--- a/outside.go
+++ b/outside.go
@@ -455,11 +455,10 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 		return ErrPacketTooShort
 	}
 
-	version := int((data[0] >> 4) & 0x0f)
-	switch version {
-	case ipv4.Version:
+	switch data[0] & 0xf0 {
+	case ipv4.Version << 4:
 		return parseV4(data, incoming, fp)
-	case ipv6.Version:
+	case ipv6.Version << 4:
 		return parseV6(data, incoming, fp)
 	}
 	return ErrUnknownIPVersion
diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index cb1aa58..1add4ae 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -806,8 +806,11 @@ func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
 }
 
 func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
+	if pkt.GetPayload() == nil {
+		return nil
+	}
 	vpkt := pkt.(*vhostnet.VirtIOPacket)
-	err := t.vdev[q].ReceiveQueue.OfferDescriptorChains(vpkt.Chains, kick)
+	err := t.vdev[q].ReceiveQueue.OfferDescriptorChains([]uint16{vpkt.Chain}, kick)
 	vpkt.Reset() //intentionally ignoring err!
 	return err
 }
diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go
index 289f8aa..0a3b70e 100644
--- a/overlay/vhostnet/device.go
+++ b/overlay/vhostnet/device.go
@@ -123,6 +123,9 @@ func NewDevice(options ...Option) (*Device, error) {
 	if err = dev.refillReceiveQueue(); err != nil {
 		return nil, fmt.Errorf("refill receive queue: %w", err)
 	}
+	if err = dev.prefillTxQueue(); err != nil {
+		return nil, fmt.Errorf("prefill tx queue: %w", err)
+	}
 
 	dev.initialized = true
 
@@ -150,6 +153,27 @@ func (dev *Device) refillReceiveQueue() error {
 	}
 }
 
+// prefillTxQueue creates output descriptors until the transmit queue's
+// descriptor table has no free descriptors left, offering each one on the
+// transmit queue without kicking the device. It returns nil once the table
+// reports virtqueue.ErrNotEnoughFreeDescriptors; other errors are returned.
+func (dev *Device) prefillTxQueue() error {
+	dt := dev.TransmitQueue.DescriptorTable()
+	for {
+		head, _, err := dt.CreateDescriptorForOutputs()
+		if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
+			// Table is full, job is done.
+			return nil
+		}
+		if err != nil {
+			return err
+		}
+		if err = dev.TransmitQueue.OfferDescriptorChains([]uint16{head}, false); err != nil {
+			return err
+		}
+	}
+}
+
 // Close cleans up the vhost networking device within the kernel and releases
 // all resources used for it.
 // The implementation will try to release as many resources as possible and
@@ -211,15 +235,7 @@ func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*v
 func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
 	var err error
 	var idx uint16
-	if !dev.fullTable {
-		idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
-		if err == virtqueue.ErrNotEnoughFreeDescriptors {
-			dev.fullTable = true
-			idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
-		}
-	} else {
-		idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
-	}
+	idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
 	if err != nil {
 		return 0, nil, fmt.Errorf("transmit queue: %w", err)
 	}
@@ -304,26 +320,26 @@ func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement
 
 	//shift the buffer out of out:
 	pkt.payload = buf[virtio.NetHdrSize:chain.Length]
-	pkt.Chains = append(pkt.Chains, idx)
+	pkt.Chain = idx
 	return 1, nil
 }
 
 type VirtIOPacket struct {
 	payload []byte
 	//header virtio.NetHdr
-	Chains  []uint16
+	Chain   uint16
 }
 
 func NewVIO() *VirtIOPacket {
 	out := new(VirtIOPacket)
 	out.payload = nil
-	out.Chains = make([]uint16, 0, 8)
+	out.Chain = 0
 	return out
 }
 
 func (v *VirtIOPacket) Reset() {
 	v.payload = nil
-	v.Chains = v.Chains[:0]
+	v.Chain = 0
 }
 
 func (v *VirtIOPacket) GetPayload() []byte {
diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go
index 5aa4cd5..762567f 100644
--- a/overlay/virtqueue/descriptor_table.go
+++ b/overlay/virtqueue/descriptor_table.go
@@ -168,12 +168,12 @@ func (dt *DescriptorTable) releaseBuffers() error {
 	return nil
 }
 
-func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
+func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, uint32, error) {
 	//todo just fill the damn table
 
 	// Do we still have enough free descriptors?
 	if 1 > dt.freeNum {
-		return 0, ErrNotEnoughFreeDescriptors
+		return 0, 0, ErrNotEnoughFreeDescriptors
 	}
 
 	// Above validation ensured that there is at least one free descriptor, so
@@ -216,7 +216,7 @@ func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
 		dt.descriptors[dt.freeHeadIndex].next = next
 	}
 
-	return head, nil
+	return head, desc.length, nil
 }
 
 func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go
index cd06bf6..27daa0a 100644
--- a/overlay/virtqueue/split_virtqueue.go
+++ b/overlay/virtqueue/split_virtqueue.go
@@ -301,7 +301,6 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
 // and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
 
 func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
-	// Create a descriptor chain for the given buffers.
 	var (
 		head uint16
 		err  error
@@ -350,18 +349,6 @@ func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
 	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
 }
 
-func (sq *SplitQueue) OfferDescriptor(chain uint16, kick bool) error {
-	// Make the descriptor chain available to the device.
-	sq.availableRing.offerSingle(chain)
-
-	// Notify the device to make it process the updated available ring.
-	if kick {
-		return sq.Kick()
-	}
-
-	return nil
-}
-
 func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
 	// Make the descriptor chain available to the device.
 	sq.availableRing.offer(chains)
diff --git a/overlay/virtqueue/used_ring.go b/overlay/virtqueue/used_ring.go
index acf65fe..896af1a 100644
--- a/overlay/virtqueue/used_ring.go
+++ b/overlay/virtqueue/used_ring.go
@@ -153,11 +153,7 @@ func (r *UsedRing) takeOne() (UsedElement, bool) {
 }
 
 // InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
-func (r *UsedRing) InitOfferSingle(x uint16, size int) {
-	//always called under lock
-	//r.mu.Lock()
-	//defer r.mu.Unlock()
-
+func (r *UsedRing) InitOfferSingle(x uint16, size uint32) {
 	offset := 0
 
 	// Add descriptor chain heads to the ring.
@@ -166,10 +162,8 @@ func (r *UsedRing) InitOfferSingle(x uint16, size int) {
 	// size) is always a power of 2 and smaller than the highest possible
 	// 16-bit value.
 	insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
-	r.ring[insertIndex] = UsedElement{
-		DescriptorIndex: uint32(x),
-		Length:          uint32(size),
-	}
+	r.ring[insertIndex].DescriptorIndex = uint32(x)
+	r.ring[insertIndex].Length = size
 
 	// Increase the ring index by the number of descriptor chains added to the ring.
 	*r.ringIndex += 1
diff --git a/packet/packet.go b/packet/packet.go
index 5ae6db6..50b8669 100644
--- a/packet/packet.go
+++ b/packet/packet.go
@@ -87,17 +87,20 @@ func (p *UDPPacket) updateCtrl(ctrlLen int) {
 	if len(p.Control) == 0 {
 		return
 	}
-	cmsgs, err := unix.ParseSocketControlMessage(p.Control)
-	if err != nil {
-		return // oh well
-	}
 
-	for _, c := range cmsgs {
-		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
+	// Parse the control messages one at a time rather than stopping after
+	// the first: UDP_GRO is not guaranteed to be the first entry.
+	for rest := p.Control; len(rest) >= unix.CmsgLen(0); {
+		header, data, remain, err := unix.ParseOneSocketControlMessage(rest)
+		if err != nil {
+			return // oh well
+		}
+		if header.Level == unix.SOL_UDP && header.Type == unix.UDP_GRO && len(data) >= 2 {
 			p.wasSegmented = true
-			p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
+			p.SegSize = int(binary.LittleEndian.Uint16(data[:2]))
 			return
 		}
+		rest = remain
 	}
 }
 