diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go
index e4e65cd..17aa11d 100644
--- a/overlay/vhostnet/device.go
+++ b/overlay/vhostnet/device.go
@@ -298,7 +298,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
 func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
 	//read first element to see how many descriptors we need:
 	pkt.Payload = pkt.Payload[:cap(pkt.Payload)]
-	n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload)
+	n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload, int(chains[0].Length)) //todo
 	if err != nil {
 		return 0, err
 	}
@@ -333,7 +333,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
 	i := 1
 	// we used chain 0 already
 	for i = 1; i < len(chains); i++ {
-		n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:])
+		n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
 		if err != nil {
 			// When this fails we may miss to free some descriptor chains. We
 			// could try to mitigate this by deferring the freeing somehow, but
diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go
index 515facf..7a29c9a 100644
--- a/overlay/virtqueue/descriptor_table.go
+++ b/overlay/virtqueue/descriptor_table.go
@@ -349,7 +349,7 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
 	return
 }
 
-func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (int, error) {
+func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
 	if int(head) > len(dt.descriptors) {
 		return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
 	}
@@ -387,7 +387,9 @@ func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (
 		next = desc.next
 	}
-
+	if maxLen > 0 {
+		//todo length = min(maxLen, length)
+	}
 	//set out to length:
 	out = out[:length]
 
@@ -399,7 +401,7 @@ func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (
 		// The descriptor address points to memory not managed by Go, so this
 		// conversion is safe. See https://github.com/golang/go/issues/58625
 		//goland:noinspection GoVetUnsafePointer
-		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
 		copied += copy(out[copied:], bs)
 
 		// Is this the tail of the chain?
diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go
index 5d2620d..dec9a2b 100644
--- a/overlay/virtqueue/split_virtqueue.go
+++ b/overlay/virtqueue/split_virtqueue.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"os"
 	"sync"
+	"syscall"
 
 	"github.com/slackhq/nebula/overlay/eventfd"
 	"golang.org/x/sys/unix"
@@ -30,16 +31,6 @@ type SplitQueue struct {
 	// chains and put them in the used ring.
 	callEventFD eventfd.EventFD
 
-	// UsedChains is a chanel that receives [UsedElement]s for descriptor chains
-	// that were used by the device.
-	UsedChains chan UsedElement
-
-	// moreFreeDescriptors is a channel that signals when any descriptors were
-	// put back into the free chain of the descriptor table. This is used to
-	// unblock methods waiting for available room in the queue to create new
-	// descriptor chains again.
-	moreFreeDescriptors chan struct{}
-
 	// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
 	// used buffer notifications. It blocks until the goroutine ended.
 	stop func() error
@@ -131,10 +122,6 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
 		return nil, fmt.Errorf("initialize descriptors: %w", err)
 	}
 
-	// Initialize channels.
-	sq.UsedChains = make(chan UsedElement, queueSize)
-	sq.moreFreeDescriptors = make(chan struct{})
-
 	sq.epoll, err = eventfd.NewEpoll()
 	if err != nil {
 		return nil, err
@@ -366,7 +353,8 @@ func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]by
 			// Wait for more free descriptors to be put back into the queue.
 			// If the number of free descriptors is still not sufficient, we'll
 			// land here again.
-			<-sq.moreFreeDescriptors
+			//todo should never happen
+			syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier
 			continue
 		}
 		return nil, fmt.Errorf("create descriptor chain: %w", err)
@@ -400,9 +388,9 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
 	return sq.descriptorTable.getDescriptorChain(head)
 }
 
-func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte) (int, error) {
+func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
 	sq.ensureInitialized()
-	return sq.descriptorTable.getDescriptorChainContents(head, out)
+	return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
 }
 
 // FreeDescriptorChain frees the descriptor chain with the given head index.
@@ -415,20 +403,11 @@ func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte) (int,
 // When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
 // are waiting for free room in the queue, they may become unblocked by this.
 func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
-	sq.ensureInitialized()
-
 	//not called under lock
 	if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
 		return fmt.Errorf("free: %w", err)
 	}
 
-	// There is more free room in the descriptor table now.
-	// This is a fire-and-forget signal, so do not block when nobody listens.
-	select { //todo eliminate
-	case sq.moreFreeDescriptors <- struct{}{}:
-	default:
-	}
-
 	return nil
 }
 
@@ -473,10 +452,6 @@ func (sq *SplitQueue) Close() error {
 			errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
 		}
 
-		// The stop function blocked until the goroutine ended, so the channel
-		// can now safely be closed.
-		close(sq.UsedChains)
-
 		// Make sure that this code block is executed only once.
 		sq.stop = nil
 	}
diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index 8861e4d..dac89d6 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -184,6 +184,13 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
 	return u.writeTo6(b, ip)
 }
 
+func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
+	if u.isV4 {
+		return u.writeTo4(b, ip)
+	}
+	return u.writeTo6(b, ip)
+}
+
 func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6