a little cleaner

This commit is contained in:
JackDoan
2025-11-13 12:47:48 -06:00
parent 4e4a85a891
commit 994bc8c32b
4 changed files with 108 additions and 108 deletions

View File

@@ -91,11 +91,13 @@ func NewDevice(options ...Option) (*Device, error) {
return nil, fmt.Errorf("set features: %w", err)
}
itemSize := os.Getpagesize() * 4 //todo config
// Initialize and register the queues needed for the networking device.
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize); err != nil {
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize); err != nil {
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
@@ -203,12 +205,12 @@ func (dev *Device) ensureInitialized() {
// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int) (*virtqueue.SplitQueue, error) {
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
var (
queue *virtqueue.SplitQueue
err error
)
if queue, err = virtqueue.NewSplitQueue(queueSize); err != nil {
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create virtqueue: %w", err)
}
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {

View File

@@ -32,7 +32,7 @@ func TestDescriptorTable_DescriptorChains(t *testing.T) {
// Use a very short queue size to not make this test overly verbose.
const queueSize = 8
pageSize := os.Getpagesize()
pageSize := os.Getpagesize() * 2
// Initialize descriptor table.
dt := DescriptorTable{

View File

@@ -43,13 +43,11 @@ type SplitQueue struct {
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
// specifies the number of entries/buffers the queue can hold. This also affects
// the memory consumption.
func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
if err = CheckQueueSize(queueSize); err != nil {
return nil, err
}
itemSize := os.Getpagesize() * 4 //todo config
if itemSize%os.Getpagesize() != 0 {
return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
}

View File

@@ -34,103 +34,103 @@ func TestUsedRing_MemoryLayout(t *testing.T) {
}, memory)
}
func TestUsedRing_Take(t *testing.T) {
const queueSize = 8
tests := []struct {
name string
ring []UsedElement
ringIndex uint16
lastIndex uint16
expected []UsedElement
}{
{
name: "nothing new",
ring: []UsedElement{
{DescriptorIndex: 1},
{DescriptorIndex: 2},
{DescriptorIndex: 3},
{DescriptorIndex: 4},
{},
{},
{},
{},
},
ringIndex: 4,
lastIndex: 4,
expected: nil,
},
{
name: "no overflow",
ring: []UsedElement{
{DescriptorIndex: 1},
{DescriptorIndex: 2},
{DescriptorIndex: 3},
{DescriptorIndex: 4},
{},
{},
{},
{},
},
ringIndex: 4,
lastIndex: 1,
expected: []UsedElement{
{DescriptorIndex: 2},
{DescriptorIndex: 3},
{DescriptorIndex: 4},
},
},
{
name: "ring overflow",
ring: []UsedElement{
{DescriptorIndex: 9},
{DescriptorIndex: 10},
{DescriptorIndex: 3},
{DescriptorIndex: 4},
{DescriptorIndex: 5},
{DescriptorIndex: 6},
{DescriptorIndex: 7},
{DescriptorIndex: 8},
},
ringIndex: 10,
lastIndex: 7,
expected: []UsedElement{
{DescriptorIndex: 8},
{DescriptorIndex: 9},
{DescriptorIndex: 10},
},
},
{
name: "index overflow",
ring: []UsedElement{
{DescriptorIndex: 9},
{DescriptorIndex: 10},
{DescriptorIndex: 3},
{DescriptorIndex: 4},
{DescriptorIndex: 5},
{DescriptorIndex: 6},
{DescriptorIndex: 7},
{DescriptorIndex: 8},
},
ringIndex: 2,
lastIndex: 65535,
expected: []UsedElement{
{DescriptorIndex: 8},
{DescriptorIndex: 9},
{DescriptorIndex: 10},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
memory := make([]byte, usedRingSize(queueSize))
r := newUsedRing(queueSize, memory)
copy(r.ring, tt.ring)
*r.ringIndex = tt.ringIndex
r.lastIndex = tt.lastIndex
assert.Equal(t, tt.expected, r.take())
})
}
}
//func TestUsedRing_Take(t *testing.T) {
// const queueSize = 8
//
// tests := []struct {
// name string
// ring []UsedElement
// ringIndex uint16
// lastIndex uint16
// expected []UsedElement
// }{
// {
// name: "nothing new",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 4,
// expected: nil,
// },
// {
// name: "no overflow",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 1,
// expected: []UsedElement{
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// },
// },
// {
// name: "ring overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 10,
// lastIndex: 7,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// {
// name: "index overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 2,
// lastIndex: 65535,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// memory := make([]byte, usedRingSize(queueSize))
// r := newUsedRing(queueSize, memory)
//
// copy(r.ring, tt.ring)
// *r.ringIndex = tt.ringIndex
// r.lastIndex = tt.lastIndex
//
// assert.Equal(t, tt.expected, r.take())
// })
// }
//}