diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go index cb862c5..435424b 100644 --- a/overlay/vhostnet/device.go +++ b/overlay/vhostnet/device.go @@ -123,9 +123,6 @@ func NewDevice(options ...Option) (*Device, error) { if err = dev.refillReceiveQueue(); err != nil { return nil, fmt.Errorf("refill receive queue: %w", err) } - if err = dev.refillTransmitQueue(); err != nil { - return nil, fmt.Errorf("refill receive queue: %w", err) - } dev.initialized = true @@ -153,22 +150,6 @@ func (dev *Device) refillReceiveQueue() error { } } -func (dev *Device) refillTransmitQueue() error { - //for { - // desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs() - // if err != nil { - // if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) { - // // Queue is full, job is done. - // return nil - // } - // return fmt.Errorf("offer descriptor chain: %w", err) - // } else { - // dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0) - // } - //} - return nil -} - // Close cleans up the vhost networking device within the kernel and releases // all resources used for it. // The implementation will try to release as many resources as possible and @@ -214,14 +195,6 @@ func (dev *Device) Close() error { return errors.Join(errs...) } -// ensureInitialized is used as a guard to prevent methods to be called on an -// uninitialized instance. -func (dev *Device) ensureInitialized() { - if !dev.initialized { - panic("device is not initialized") - } -} - // createQueue creates a new virtqueue and registers it with the vhost device // using the given index. func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) { @@ -238,30 +211,10 @@ func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*v return queue, nil } -// truncateBuffers returns a new list of buffers whose combined length matches -// exactly the specified length. When the specified length exceeds the length of -// the buffers, this is an error. When it is smaller, the buffer list will be -// truncated accordingly. -func truncateBuffers(buffers [][]byte, length int) (out [][]byte) { - for _, buffer := range buffers { - if length < len(buffer) { - out = append(out, buffer[:length]) - return - } - out = append(out, buffer) - length -= len(buffer) - } - if length > 0 { - panic("length exceeds the combined length of all buffers") - } - return -} - func (dev *Device) GetPacketForTx() (uint16, []byte, error) { var err error var idx uint16 if !dev.fullTable { - idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs() if err == virtqueue.ErrNotEnoughFreeDescriptors { dev.fullTable = true @@ -393,7 +346,7 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) { //todo optimize? var chains []virtqueue.UsedElement var err error - //if len(dev.extraRx) == 0 { + chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out)) if err != nil { return 0, err @@ -401,9 +354,6 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) { if len(chains) == 0 { return 0, nil } - //} else { - // chains = dev.extraRx - //} numPackets := 0 chainsIdx := 0 @@ -418,10 +368,5 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) { chainsIdx += numChains } - // Now that we have copied all buffers, we can recycle the used descriptor chains - //if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil { - // return 0, err - //} - return numPackets, nil } diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go index 44b8494..298036f 100644 --- a/overlay/virtqueue/descriptor_table.go +++ b/overlay/virtqueue/descriptor_table.go @@ -172,115 +172,6 @@ func (dt *DescriptorTable) releaseBuffers() error { return nil } -// createDescriptorChain creates a new descriptor chain within the descriptor -// table which contains a number of device-readable buffers (out buffers) and -// device-writable buffers (in buffers). -// -// All buffers in the outBuffers slice will be concatenated by chaining -// descriptors, one for each buffer in the slice. The size of the single buffers -// must not exceed the size of a memory page (see [os.Getpagesize]). -// When numInBuffers is greater than zero, the given number of device-writable -// descriptors will be appended to the end of the chain, each referencing a -// whole memory page. -// -// The index of the head of the new descriptor chain will be returned. Callers -// should make sure to free the descriptor chain using [freeDescriptorChain] -// after it was used by the device. -// -// When there are not enough free descriptors to hold the given number of -// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the -// caller should try again after some descriptor chains were used by the device -// and returned back into the free chain. -func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) { - // Calculate the number of descriptors needed to build the chain. - numDesc := uint16(len(outBuffers) + numInBuffers) - - // Descriptor chains must always contain at least one descriptor. - if numDesc < 1 { - return 0, ErrDescriptorChainEmpty - } - - // Do we still have enough free descriptors? - if numDesc > dt.freeNum { - return 0, ErrNotEnoughFreeDescriptors - } - - // Above validation ensured that there is at least one free descriptor, so - // the free descriptor chain head should be valid. - if dt.freeHeadIndex == noFreeHead { - panic("free descriptor chain head is unset but there should be free descriptors") - } - - // To avoid having to iterate over the whole table to find the descriptor - // pointing to the head just to replace the free head, we instead always - // create descriptor chains from the descriptors coming after the head. - // This way we only have to touch the head as a last resort, when all other - // descriptors are already used. - head := dt.descriptors[dt.freeHeadIndex].next - next := head - tail := head - for i, buffer := range outBuffers { - desc := &dt.descriptors[next] - checkUnusedDescriptorLength(next, desc) - - if len(buffer) > dt.itemSize { - // The caller should already prevent that from happening. - panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize)) - } - - // Copy the buffer to the memory referenced by the descriptor. - // The descriptor address points to memory not managed by Go, so this - // conversion is safe. See https://github.com/golang/go/issues/58625 - //goland:noinspection GoVetUnsafePointer - copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer) - desc.length = uint32(len(buffer)) - - // Clear the flags in case there were any others set. - desc.flags = descriptorFlagHasNext - - tail = next - next = desc.next - } - for range numInBuffers { - desc := &dt.descriptors[next] - checkUnusedDescriptorLength(next, desc) - - // Give the device the maximum available number of bytes to write into. - desc.length = uint32(dt.itemSize) - - // Mark the descriptor as device-writable. - desc.flags = descriptorFlagHasNext | descriptorFlagWritable - - tail = next - next = desc.next - } - - // The last descriptor should end the chain. - tailDesc := &dt.descriptors[tail] - tailDesc.flags &= ^descriptorFlagHasNext - tailDesc.next = 0 // Not necessary to clear this, it's just for looks. - - dt.freeNum -= numDesc - - if dt.freeNum == 0 { - // The last descriptor in the chain should be the free chain head - // itself. - if tail != dt.freeHeadIndex { - panic("descriptor chain takes up all free descriptors but does not end with the free chain head") - } - - // When this new chain takes up all remaining descriptors, we no longer - // have a free chain. - dt.freeHeadIndex = noFreeHead - } else { - // We took some descriptors out of the free chain, so make sure to close - // the circle again. - dt.descriptors[dt.freeHeadIndex].next = next - } - - return head, nil -} - func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) { //todo just fill the damn table // Do we still have enough free descriptors? @@ -490,73 +381,6 @@ func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]by return nil } -func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) { - if int(head) > len(dt.descriptors) { - return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) - } - - // Iterate over the chain. The iteration is limited to the queue size to - // avoid ending up in an endless loop when things go very wrong. - - length := 0 - //find length - next := head - for range len(dt.descriptors) { - if next == dt.freeHeadIndex { - return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain) - } - - desc := &dt.descriptors[next] - - if desc.flags&descriptorFlagWritable == 0 { - return 0, fmt.Errorf("receive queue contains device-readable buffer") - } - length += int(desc.length) - - // Is this the tail of the chain? - if desc.flags&descriptorFlagHasNext == 0 { - break - } - - // Detect loops. - if desc.next == head { - return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain) - } - - next = desc.next - } - if maxLen > 0 { - //todo length = min(maxLen, length) - } - //set out to length: - out = out[:length] - - //now do the copying - copied := 0 - for range len(dt.descriptors) { - desc := &dt.descriptors[next] - - // The descriptor address points to memory not managed by Go, so this - // conversion is safe. See https://github.com/golang/go/issues/58625 - //goland:noinspection GoVetUnsafePointer - bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length)) - copied += copy(out[copied:], bs) - - // Is this the tail of the chain? - if desc.flags&descriptorFlagHasNext == 0 { - break - } - - // we did this already, no need to detect loops. - next = desc.next - } - if copied != length { - panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied)) - } - - return length, nil -} - // freeDescriptorChain can be used to free a descriptor chain when it is no // longer in use. The descriptor chain that starts with the given index will be // put back into the free chain, so the descriptors can be used for later calls diff --git a/overlay/virtqueue/descriptor_table_internal_test.go b/overlay/virtqueue/descriptor_table_internal_test.go deleted file mode 100644 index 50803e0..0000000 --- a/overlay/virtqueue/descriptor_table_internal_test.go +++ /dev/null @@ -1,407 +0,0 @@ -package virtqueue - -import ( - "os" - "testing" - "unsafe" - - "github.com/stretchr/testify/assert" -) - -func TestDescriptorTable_InitializeDescriptors(t *testing.T) { - const queueSize = 32 - - dt := DescriptorTable{ - descriptors: make([]Descriptor, queueSize), - } - - assert.NoError(t, dt.initializeDescriptors()) - t.Cleanup(func() { - assert.NoError(t, dt.releaseBuffers()) - }) - - for i, descriptor := range dt.descriptors { - assert.NotZero(t, descriptor.address) - assert.Zero(t, descriptor.length) - assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags) - assert.EqualValues(t, (i+1)%queueSize, descriptor.next) - } -} - -func TestDescriptorTable_DescriptorChains(t *testing.T) { - // Use a very short queue size to not make this test overly verbose. - const queueSize = 8 - - pageSize := os.Getpagesize() * 2 - - // Initialize descriptor table. - dt := DescriptorTable{ - descriptors: make([]Descriptor, queueSize), - } - assert.NoError(t, dt.initializeDescriptors()) - t.Cleanup(func() { - assert.NoError(t, dt.releaseBuffers()) - }) - - // Some utilities for easier checking if the descriptor table looks as - // expected. - type desc struct { - buffer []byte - flags descriptorFlag - next uint16 - } - assertDescriptorTable := func(expected [queueSize]desc) { - for i := 0; i < queueSize; i++ { - actualDesc := &dt.descriptors[i] - expectedDesc := &expected[i] - assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length) - if len(expectedDesc.buffer) > 0 { - //goland:noinspection GoVetUnsafePointer - assert.EqualValues(t, - unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length), - expectedDesc.buffer) - } - assert.Equal(t, expectedDesc.flags, actualDesc.flags) - if expectedDesc.flags&descriptorFlagHasNext != 0 { - assert.Equal(t, expectedDesc.next, actualDesc.next) - } - } - } - - // Initial state: All descriptors are in the free chain. - assert.Equal(t, uint16(0), dt.freeHeadIndex) - assert.Equal(t, uint16(8), dt.freeNum) - assertDescriptorTable([queueSize]desc{ - { - // Free head. - flags: descriptorFlagHasNext, - next: 1, - }, - { - flags: descriptorFlagHasNext, - next: 2, - }, - { - flags: descriptorFlagHasNext, - next: 3, - }, - { - flags: descriptorFlagHasNext, - next: 4, - }, - { - flags: descriptorFlagHasNext, - next: 5, - }, - { - flags: descriptorFlagHasNext, - next: 6, - }, - { - flags: descriptorFlagHasNext, - next: 7, - }, - { - flags: descriptorFlagHasNext, - next: 0, - }, - }) - - // Create the first chain. - firstChain, err := dt.createDescriptorChain([][]byte{ - makeTestBuffer(t, 26), - makeTestBuffer(t, 256), - }, 1) - assert.NoError(t, err) - assert.Equal(t, uint16(1), firstChain) - - // Now there should be a new chain next to the free chain. - assert.Equal(t, uint16(0), dt.freeHeadIndex) - assert.Equal(t, uint16(5), dt.freeNum) - assertDescriptorTable([queueSize]desc{ - { - // Free head. - flags: descriptorFlagHasNext, - next: 4, - }, - { - // Head of first chain. - buffer: makeTestBuffer(t, 26), - flags: descriptorFlagHasNext, - next: 2, - }, - { - buffer: makeTestBuffer(t, 256), - flags: descriptorFlagHasNext, - next: 3, - }, - { - // Tail of first chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - flags: descriptorFlagHasNext, - next: 5, - }, - { - flags: descriptorFlagHasNext, - next: 6, - }, - { - flags: descriptorFlagHasNext, - next: 7, - }, - { - flags: descriptorFlagHasNext, - next: 0, - }, - }) - - // Create a second chain with only a single in buffer. - secondChain, err := dt.createDescriptorChain(nil, 1) - assert.NoError(t, err) - assert.Equal(t, uint16(4), secondChain) - - // Now there should be two chains next to the free chain. - assert.Equal(t, uint16(0), dt.freeHeadIndex) - assert.Equal(t, uint16(4), dt.freeNum) - assertDescriptorTable([queueSize]desc{ - { - // Free head. - flags: descriptorFlagHasNext, - next: 5, - }, - { - // Head of the first chain. - buffer: makeTestBuffer(t, 26), - flags: descriptorFlagHasNext, - next: 2, - }, - { - buffer: makeTestBuffer(t, 256), - flags: descriptorFlagHasNext, - next: 3, - }, - { - // Tail of the first chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - // Head and tail of the second chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - flags: descriptorFlagHasNext, - next: 6, - }, - { - flags: descriptorFlagHasNext, - next: 7, - }, - { - flags: descriptorFlagHasNext, - next: 0, - }, - }) - - // Create a third chain taking up all remaining descriptors. - thirdChain, err := dt.createDescriptorChain([][]byte{ - makeTestBuffer(t, 42), - makeTestBuffer(t, 96), - makeTestBuffer(t, 33), - makeTestBuffer(t, 222), - }, 0) - assert.NoError(t, err) - assert.Equal(t, uint16(5), thirdChain) - - // Now there should be three chains and no free chain. - assert.Equal(t, noFreeHead, dt.freeHeadIndex) - assert.Equal(t, uint16(0), dt.freeNum) - assertDescriptorTable([queueSize]desc{ - { - // Tail of the third chain. - buffer: makeTestBuffer(t, 222), - }, - { - // Head of the first chain. - buffer: makeTestBuffer(t, 26), - flags: descriptorFlagHasNext, - next: 2, - }, - { - buffer: makeTestBuffer(t, 256), - flags: descriptorFlagHasNext, - next: 3, - }, - { - // Tail of the first chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - // Head and tail of the second chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - // Head of the third chain. - buffer: makeTestBuffer(t, 42), - flags: descriptorFlagHasNext, - next: 6, - }, - { - buffer: makeTestBuffer(t, 96), - flags: descriptorFlagHasNext, - next: 7, - }, - { - buffer: makeTestBuffer(t, 33), - flags: descriptorFlagHasNext, - next: 0, - }, - }) - - // Free the third chain. - assert.NoError(t, dt.freeDescriptorChain(thirdChain)) - - // Now there should be two chains and a free chain again. - assert.Equal(t, uint16(5), dt.freeHeadIndex) - assert.Equal(t, uint16(4), dt.freeNum) - assertDescriptorTable([queueSize]desc{ - { - flags: descriptorFlagHasNext, - next: 5, - }, - { - // Head of the first chain. - buffer: makeTestBuffer(t, 26), - flags: descriptorFlagHasNext, - next: 2, - }, - { - buffer: makeTestBuffer(t, 256), - flags: descriptorFlagHasNext, - next: 3, - }, - { - // Tail of the first chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - // Head and tail of the second chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - // Free head. - flags: descriptorFlagHasNext, - next: 6, - }, - { - flags: descriptorFlagHasNext, - next: 7, - }, - { - flags: descriptorFlagHasNext, - next: 0, - }, - }) - - // Free the first chain. - assert.NoError(t, dt.freeDescriptorChain(firstChain)) - - // Now there should be only a single chain next to the free chain. - assert.Equal(t, uint16(5), dt.freeHeadIndex) - assert.Equal(t, uint16(7), dt.freeNum) - assertDescriptorTable([queueSize]desc{ - { - flags: descriptorFlagHasNext, - next: 5, - }, - { - flags: descriptorFlagHasNext, - next: 2, - }, - { - flags: descriptorFlagHasNext, - next: 3, - }, - { - flags: descriptorFlagHasNext, - next: 6, - }, - { - // Head and tail of the second chain. - buffer: make([]byte, pageSize), - flags: descriptorFlagWritable, - }, - { - // Free head. - flags: descriptorFlagHasNext, - next: 1, - }, - { - flags: descriptorFlagHasNext, - next: 7, - }, - { - flags: descriptorFlagHasNext, - next: 0, - }, - }) - - // Free the second chain. - assert.NoError(t, dt.freeDescriptorChain(secondChain)) - - // Now all descriptors should be in the free chain again. - assert.Equal(t, uint16(5), dt.freeHeadIndex) - assert.Equal(t, uint16(8), dt.freeNum) - assertDescriptorTable([queueSize]desc{ - { - flags: descriptorFlagHasNext, - next: 5, - }, - { - flags: descriptorFlagHasNext, - next: 2, - }, - { - flags: descriptorFlagHasNext, - next: 3, - }, - { - flags: descriptorFlagHasNext, - next: 6, - }, - { - flags: descriptorFlagHasNext, - next: 1, - }, - { - // Free head. - flags: descriptorFlagHasNext, - next: 4, - }, - { - flags: descriptorFlagHasNext, - next: 7, - }, - { - flags: descriptorFlagHasNext, - next: 0, - }, - }) -} - -func makeTestBuffer(t *testing.T, length int) []byte { - t.Helper() - buf := make([]byte, length) - for i := 0; i < length; i++ { - buf[i] = byte(length - i) - } - return buf -} diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go index 0da7734..59421ea 100644 --- a/overlay/virtqueue/split_virtqueue.go +++ b/overlay/virtqueue/split_virtqueue.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "os" - "syscall" "github.com/slackhq/nebula/overlay/eventfd" "golang.org/x/sys/unix" @@ -186,28 +185,6 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error { } } -// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s -func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) { - var n int - var err error - for ctx.Err() == nil { - - // Wait for a signal from the device. - if n, err = sq.epoll.Block(); err != nil { - return nil, fmt.Errorf("wait: %w", err) - } - if n > 0 { - stillNeedToTake, out := sq.usedRing.take(-1) - sq.more = stillNeedToTake - if stillNeedToTake == 0 { - _ = sq.epoll.Clear() //??? - } - return out, nil - } - } - return nil, ctx.Err() -} - func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) { var n int var err error @@ -326,53 +303,6 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) { return head, nil } -func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) { - // TODO change this - // Each descriptor can only hold a whole memory page, so split large out - // buffers into multiple smaller ones. - outBuffers = splitBuffers(outBuffers, sq.itemSize) - - chains := make([]uint16, len(outBuffers)) - - // Create a descriptor chain for the given buffers. - var ( - head uint16 - err error - ) - for i := range outBuffers { - for { - bufs := [][]byte{prepend, outBuffers[i]} - head, err = sq.descriptorTable.createDescriptorChain(bufs, 0) - if err == nil { - break - } - - // I don't wanna use errors.Is, it's slow - //goland:noinspection GoDirectComparisonOfErrors - if err == ErrNotEnoughFreeDescriptors { - // Wait for more free descriptors to be put back into the queue. - // If the number of free descriptors is still not sufficient, we'll - // land here again. - //todo should never happen - syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier - continue - } - return nil, fmt.Errorf("create descriptor chain: %w", err) - } - chains[i] = head - } - - // Make the descriptor chain available to the device. - sq.availableRing.offer(chains) - - // Notify the device to make it process the updated available ring. - if err := sq.kickEventFD.Kick(); err != nil { - return chains, fmt.Errorf("notify device: %w", err) - } - - return chains, nil -} - // GetDescriptorChain returns the device-readable buffers (out buffers) and // device-writable buffers (in buffers) of the descriptor chain with the given // head index. @@ -392,10 +322,6 @@ func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) { return sq.descriptorTable.getDescriptorItem(head) } -func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) { - return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen) -} - func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error { return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers) } @@ -486,14 +412,6 @@ func (sq *SplitQueue) Close() error { return errors.Join(errs...) } -// ensureInitialized is used as a guard to prevent methods to be called on an -// uninitialized instance. -func (sq *SplitQueue) ensureInitialized() { - if sq.buf == nil { - panic("used ring is not initialized") - } -} - func align(index, alignment int) int { remainder := index % alignment if remainder == 0 { @@ -501,30 +419,3 @@ func align(index, alignment int) int { } return index + alignment - remainder } - -// splitBuffers processes a list of buffers and splits each buffer that is -// larger than the size limit into multiple smaller buffers. -// If none of the buffers are too big though, do nothing, to avoid allocation for now -func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte { - for i := range buffers { - if len(buffers[i]) > sizeLimit { - return reallySplitBuffers(buffers, sizeLimit) - } - } - return buffers -} - -func reallySplitBuffers(buffers [][]byte, sizeLimit int) [][]byte { - result := make([][]byte, 0, len(buffers)) - for _, buffer := range buffers { - for added := 0; added < len(buffer); added += sizeLimit { - if len(buffer)-added <= sizeLimit { - result = append(result, buffer[added:]) - break - } - result = append(result, buffer[added:added+sizeLimit]) - } - } - - return result -} diff --git a/overlay/virtqueue/split_virtqueue_internal_test.go b/overlay/virtqueue/split_virtqueue_internal_test.go deleted file mode 100644 index d353df7..0000000 --- a/overlay/virtqueue/split_virtqueue_internal_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package virtqueue - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSplitQueue_MemoryAlignment(t *testing.T) { - tests := []struct { - name string - queueSize int - }{ - { - name: "minimal queue size", - queueSize: 1, - }, - { - name: "small queue size", - queueSize: 8, - }, - { - name: "large queue size", - queueSize: 256, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sq, err := NewSplitQueue(tt.queueSize) - require.NoError(t, err) - - assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment) - assert.Zero(t, sq.availableRing.Address()%availableRingAlignment) - assert.Zero(t, sq.usedRing.Address()%usedRingAlignment) - }) - } -} - -func TestSplitBuffers(t *testing.T) { - const sizeLimit = 16 - tests := []struct { - name string - buffers [][]byte - expected [][]byte - }{ - { - name: "no buffers", - buffers: make([][]byte, 0), - expected: make([][]byte, 0), - }, - { - name: "small", - buffers: [][]byte{ - make([]byte, 11), - }, - expected: [][]byte{ - make([]byte, 11), - }, - }, - { - name: "exact size", - buffers: [][]byte{ - make([]byte, sizeLimit), - }, - expected: [][]byte{ - make([]byte, sizeLimit), - }, - }, - { - name: "large", - buffers: [][]byte{ - make([]byte, 42), - }, - expected: [][]byte{ - make([]byte, 16), - make([]byte, 16), - make([]byte, 10), - }, - }, - { - name: "mixed", - buffers: [][]byte{ - make([]byte, 7), - make([]byte, 30), - make([]byte, 15), - make([]byte, 32), - }, - expected: [][]byte{ - make([]byte, 7), - make([]byte, 16), - make([]byte, 14), - make([]byte, 15), - make([]byte, 16), - make([]byte, 16), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual := splitBuffers(tt.buffers, sizeLimit) - assert.Equal(t, tt.expected, actual) - }) - } -}