From e0f93c9d4b7bdbb4930e610e920153448658f13a Mon Sep 17 00:00:00 2001
From: JackDoan
Date: Sat, 8 Nov 2025 12:03:47 -0600
Subject: [PATCH] overlay: use a single mmap for virtqueue descriptor buffers

---
 overlay/tun_linux.go                  |  2 +-
 overlay/vhost/memory.go               |  6 +-
 overlay/vhostnet/device.go            |  3 +-
 overlay/virtqueue/descriptor_table.go | 86 ++++++++++++++-------------
 4 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index e14339a..a2250fb 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -154,7 +154,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 
 	vdev, err := vhostnet.NewDevice(
 		vhostnet.WithBackendFD(fd),
-		vhostnet.WithQueueSize(16), //todo config
+		vhostnet.WithQueueSize(8192), //todo config
 	)
 	if err != nil {
 		return nil, err
diff --git a/overlay/vhost/memory.go b/overlay/vhost/memory.go
index 64ae34e..d9a94c3 100644
--- a/overlay/vhost/memory.go
+++ b/overlay/vhost/memory.go
@@ -3,7 +3,6 @@ package vhost
 import (
 	"encoding/binary"
 	"fmt"
-	"os"
 	"unsafe"
 
 	"github.com/slackhq/nebula/overlay/virtqueue"
@@ -33,15 +32,14 @@ type MemoryLayout []MemoryRegion
 // NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
 // memory pages used by the descriptor tables of the given queues.
 func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
-	pageSize := os.Getpagesize()
 	regions := make([]MemoryRegion, 0)
 	for _, queue := range queues {
-		for _, address := range queue.DescriptorTable().BufferAddresses() {
+		for address, size := range queue.DescriptorTable().BufferAddresses() {
 			regions = append(regions, MemoryRegion{
 				// There is no virtualization in play here, so the guest address
 				// is the same as in the host's userspace.
 				GuestPhysicalAddress: address,
-				Size:                 uint64(pageSize),
+				Size:                 uint64(size),
 				UserspaceAddress:     address,
 			})
 		}
diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go
index 053eaa2..b7531ee 100644
--- a/overlay/vhostnet/device.go
+++ b/overlay/vhostnet/device.go
@@ -50,7 +50,8 @@ type Device struct {
 // - [WithBackendDevice]
 //
 // Remember to call [Device.Close] after use to free up resources.
-func NewDevice(options ...Option) (_ *Device, err error) {
+func NewDevice(options ...Option) (*Device, error) {
+	var err error
 	opts := optionDefaults
 	opts.apply(options)
 	if err = opts.validate(); err != nil {
diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go
index 8f99872..6602e32 100644
--- a/overlay/virtqueue/descriptor_table.go
+++ b/overlay/virtqueue/descriptor_table.go
@@ -52,6 +52,9 @@ type DescriptorTable struct {
 	// freeNum tracks the number of descriptors which are currently not in use.
 	freeNum uint16
 
+	bufferBase uintptr
+	bufferSize int
+
 	mu sync.Mutex
 }
 
@@ -82,29 +85,24 @@ func (dt *DescriptorTable) Address() uintptr {
 	if dt.descriptors == nil {
 		panic("descriptor table is not initialized")
 	}
+	// Note: this is the address of the descriptor ring itself, not dt.bufferBase.
 	return uintptr(unsafe.Pointer(&dt.descriptors[0]))
 }
 
-// BufferAddresses returns pointers to all memory pages used by the descriptor
-// table to store buffers, independent of whether descriptors are currently in
-// use or not. These pointers can be helpful to set up memory mappings. Do not
-// use them to access or modify the memory in any way.
-// Each pointer points to a whole memory page. Use [os.Getpagesize] to get the
-// page size.
-func (dt *DescriptorTable) BufferAddresses() []uintptr {
+func (dt *DescriptorTable) Size() uintptr {
+	if dt.descriptors == nil {
+		panic("descriptor table is not initialized")
+	}
+	return uintptr(dt.bufferSize)
+}
+
+// BufferAddresses returns a map of base address to size for every buffer allocation used by the table.
+func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
 	if dt.descriptors == nil {
 		panic("descriptor table is not initialized")
 	}
 
-	dt.mu.Lock()
-	defer dt.mu.Unlock()
-
-	ptrs := make([]uintptr, len(dt.descriptors))
-	for i, desc := range dt.descriptors {
-		ptrs[i] = desc.address
-	}
-
-	return ptrs
+	return map[uintptr]int{dt.bufferBase: dt.bufferSize}
 }
 
 // initializeDescriptors allocates buffers with the size of a full memory page
@@ -116,22 +114,29 @@ func (dt *DescriptorTable) BufferAddresses() []uintptr {
 // addresses of all descriptors will be populated while their length remains
 // zero.
 func (dt *DescriptorTable) initializeDescriptors() error {
-	pageSize := os.Getpagesize()
+	numDescriptors := len(dt.descriptors)
+
+	itemSize := os.Getpagesize() //todo configurable? needs to be page-aligned
+
+	// Allocate ONE large region for all buffers
+	totalSize := itemSize * numDescriptors
+	basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
+		unix.PROT_READ|unix.PROT_WRITE,
+		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
+	if err != nil {
+		return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
+	}
 
 	dt.mu.Lock()
 	defer dt.mu.Unlock()
 
-	for i := range dt.descriptors {
-		// Allocate a full memory page for this descriptor.
-		pagePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(pageSize),
-			unix.PROT_READ|unix.PROT_WRITE,
-			unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
-		if err != nil {
-			return fmt.Errorf("allocate page for descriptor %d: %w", i, err)
-		}
+	// Store the base for cleanup later
+	dt.bufferBase = uintptr(basePtr)
+	dt.bufferSize = totalSize
 
+	for i := range dt.descriptors {
 		dt.descriptors[i] = Descriptor{
-			address: uintptr(pagePtr),
+			address: dt.bufferBase + uintptr(i*itemSize),
 			length: 0,
 			// All descriptors should form a free chain that loops around.
 			flags: descriptorFlagHasNext,
@@ -154,30 +159,27 @@ func (dt *DescriptorTable) releaseBuffers() error {
 	dt.mu.Lock()
 	defer dt.mu.Unlock()
 
-	var errs []error
-	pageSize := os.Getpagesize()
 	for i := range dt.descriptors {
 		descriptor := &dt.descriptors[i]
-		if descriptor.address == 0 {
-			continue
-		}
-
-		// The pointer points to memory not managed by Go, so this conversion
-		// is safe. See https://github.com/golang/go/issues/58625
-		//goland:noinspection GoVetUnsafePointer
-		err := unix.MunmapPtr(unsafe.Pointer(descriptor.address), uintptr(pageSize))
-		if err == nil {
-			descriptor.address = 0
-		} else {
-			errs = append(errs, fmt.Errorf("release page for descriptor %d: %w", i, err))
-		}
+		descriptor.address = 0
 	}
 
 	// As a safety measure, make sure no descriptors can be used anymore.
 	dt.freeHeadIndex = noFreeHead
 	dt.freeNum = 0
 
-	return errors.Join(errs...)
+	if dt.bufferBase != 0 {
+		// The pointer points to memory not managed by Go, so this conversion
+		// is safe. See https://github.com/golang/go/issues/58625
+		//goland:noinspection GoVetUnsafePointer
+		err := unix.MunmapPtr(unsafe.Pointer(dt.bufferBase), uintptr(dt.bufferSize))
+		if err != nil {
+			return fmt.Errorf("release buffer memory: %w", err)
+		}
+		dt.bufferBase = 0
+	}
+
+	return nil
 }
 
 // createDescriptorChain creates a new descriptor chain within the descriptor
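
Note on the new layout (not part of the patch): the sketch below shows how the single contiguous allocation is expected to be consumed, assuming itemSize == os.Getpagesize() and the 8192-entry queue configured in tun_linux.go. The memoryRegion type and the regionsFromBufferAddresses helper only mirror overlay/vhost/memory.go for illustration, and the base address is a placeholder for whatever unix.MmapPtr would return.

package main

import (
	"fmt"
	"os"
)

// memoryRegion mirrors the fields used by overlay/vhost/memory.go; it is
// redefined here only so the sketch is self-contained.
type memoryRegion struct {
	GuestPhysicalAddress uint64
	Size                 uint64
	UserspaceAddress     uint64
}

// regionsFromBufferAddresses is a hypothetical helper that consumes the
// map[uintptr]int now returned by DescriptorTable.BufferAddresses(),
// producing one memory region per allocation instead of one per page.
func regionsFromBufferAddresses(buffers map[uintptr]int) []memoryRegion {
	regions := make([]memoryRegion, 0, len(buffers))
	for address, size := range buffers {
		regions = append(regions, memoryRegion{
			// No virtualization in play: guest address == host userspace address.
			GuestPhysicalAddress: uint64(address),
			Size:                 uint64(size),
			UserspaceAddress:     uint64(address),
		})
	}
	return regions
}

func main() {
	const queueSize = 8192            // matches WithQueueSize(8192) in tun_linux.go
	itemSize := os.Getpagesize()      // matches itemSize in initializeDescriptors
	base := uintptr(0x7f0000000000)   // placeholder for the address unix.MmapPtr returns
	totalSize := queueSize * itemSize // one mapping backs every descriptor's buffer

	// After the patch, BufferAddresses() reports a single base->size entry.
	buffers := map[uintptr]int{base: totalSize}
	regions := regionsFromBufferAddresses(buffers)
	fmt.Printf("memory table entries: %d (previously %d), region size: %d bytes\n",
		len(regions), queueSize, regions[0].Size)

	// Descriptor i's buffer starts at base + i*itemSize, exactly as computed
	// in initializeDescriptors, and always lies inside the single region.
	i := 42
	addr := base + uintptr(i*itemSize)
	fmt.Printf("descriptor %d buffer at %#x (offset %d)\n", i, addr, addr-base)
}

Collapsing thousands of page-sized regions into one region per descriptor table presumably also keeps the vhost memory table small enough for the kernel's region limit once the queue size is raised to 8192.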