mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 00:15:37 +01:00
yeah
This commit is contained in:
@@ -154,7 +154,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
|
||||
vdev, err := vhostnet.NewDevice(
|
||||
vhostnet.WithBackendFD(fd),
|
||||
vhostnet.WithQueueSize(16), //todo config
|
||||
vhostnet.WithQueueSize(8192), //todo config
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -3,7 +3,6 @@ package vhost
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"unsafe"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
@@ -33,15 +32,14 @@ type MemoryLayout []MemoryRegion
|
||||
// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
|
||||
// memory pages used by the descriptor tables of the given queues.
|
||||
func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
|
||||
pageSize := os.Getpagesize()
|
||||
regions := make([]MemoryRegion, 0)
|
||||
for _, queue := range queues {
|
||||
for _, address := range queue.DescriptorTable().BufferAddresses() {
|
||||
for address, size := range queue.DescriptorTable().BufferAddresses() {
|
||||
regions = append(regions, MemoryRegion{
|
||||
// There is no virtualization in play here, so the guest address
|
||||
// is the same as in the host's userspace.
|
||||
GuestPhysicalAddress: address,
|
||||
Size: uint64(pageSize),
|
||||
Size: uint64(size),
|
||||
UserspaceAddress: address,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -50,7 +50,8 @@ type Device struct {
|
||||
// - [WithBackendDevice]
|
||||
//
|
||||
// Remember to call [Device.Close] after use to free up resources.
|
||||
func NewDevice(options ...Option) (_ *Device, err error) {
|
||||
func NewDevice(options ...Option) (*Device, error) {
|
||||
var err error
|
||||
opts := optionDefaults
|
||||
opts.apply(options)
|
||||
if err = opts.validate(); err != nil {
|
||||
|
||||
@@ -52,6 +52,9 @@ type DescriptorTable struct {
|
||||
// freeNum tracks the number of descriptors which are currently not in use.
|
||||
freeNum uint16
|
||||
|
||||
bufferBase uintptr
|
||||
bufferSize int
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -82,29 +85,24 @@ func (dt *DescriptorTable) Address() uintptr {
|
||||
if dt.descriptors == nil {
|
||||
panic("descriptor table is not initialized")
|
||||
}
|
||||
//should be same as dt.bufferBase
|
||||
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
|
||||
}
|
||||
|
||||
// BufferAddresses returns pointers to all memory pages used by the descriptor
|
||||
// table to store buffers, independent of whether descriptors are currently in
|
||||
// use or not. These pointers can be helpful to set up memory mappings. Do not
|
||||
// use them to access or modify the memory in any way.
|
||||
// Each pointer points to a whole memory page. Use [os.Getpagesize] to get the
|
||||
// page size.
|
||||
func (dt *DescriptorTable) BufferAddresses() []uintptr {
|
||||
func (dt *DescriptorTable) Size() uintptr {
|
||||
if dt.descriptors == nil {
|
||||
panic("descriptor table is not initialized")
|
||||
}
|
||||
return uintptr(dt.bufferSize)
|
||||
}
|
||||
|
||||
// BufferAddresses returns a map of pointer->size for all allocations used by the table
|
||||
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
|
||||
if dt.descriptors == nil {
|
||||
panic("descriptor table is not initialized")
|
||||
}
|
||||
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
ptrs := make([]uintptr, len(dt.descriptors))
|
||||
for i, desc := range dt.descriptors {
|
||||
ptrs[i] = desc.address
|
||||
}
|
||||
|
||||
return ptrs
|
||||
return map[uintptr]int{dt.bufferBase: dt.bufferSize}
|
||||
}
|
||||
|
||||
// initializeDescriptors allocates buffers with the size of a full memory page
|
||||
@@ -116,22 +114,29 @@ func (dt *DescriptorTable) BufferAddresses() []uintptr {
|
||||
// addresses of all descriptors will be populated while their length remains
|
||||
// zero.
|
||||
func (dt *DescriptorTable) initializeDescriptors() error {
|
||||
pageSize := os.Getpagesize()
|
||||
numDescriptors := len(dt.descriptors)
|
||||
|
||||
itemSize := os.Getpagesize() //todo configurable? needs to be page-aligned
|
||||
|
||||
// Allocate ONE large region for all buffers
|
||||
totalSize := itemSize * numDescriptors
|
||||
basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
|
||||
unix.PROT_READ|unix.PROT_WRITE,
|
||||
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
|
||||
}
|
||||
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
for i := range dt.descriptors {
|
||||
// Allocate a full memory page for this descriptor.
|
||||
pagePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(pageSize),
|
||||
unix.PROT_READ|unix.PROT_WRITE,
|
||||
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allocate page for descriptor %d: %w", i, err)
|
||||
}
|
||||
// Store the base for cleanup later
|
||||
dt.bufferBase = uintptr(basePtr)
|
||||
dt.bufferSize = totalSize
|
||||
|
||||
for i := range dt.descriptors {
|
||||
dt.descriptors[i] = Descriptor{
|
||||
address: uintptr(pagePtr),
|
||||
address: dt.bufferBase + uintptr(i*itemSize),
|
||||
length: 0,
|
||||
// All descriptors should form a free chain that loops around.
|
||||
flags: descriptorFlagHasNext,
|
||||
@@ -154,30 +159,27 @@ func (dt *DescriptorTable) releaseBuffers() error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
var errs []error
|
||||
pageSize := os.Getpagesize()
|
||||
for i := range dt.descriptors {
|
||||
descriptor := &dt.descriptors[i]
|
||||
if descriptor.address == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// The pointer points to memory not managed by Go, so this conversion
|
||||
// is safe. See https://github.com/golang/go/issues/58625
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
err := unix.MunmapPtr(unsafe.Pointer(descriptor.address), uintptr(pageSize))
|
||||
if err == nil {
|
||||
descriptor.address = 0
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("release page for descriptor %d: %w", i, err))
|
||||
}
|
||||
descriptor.address = 0
|
||||
}
|
||||
|
||||
// As a safety measure, make sure no descriptors can be used anymore.
|
||||
dt.freeHeadIndex = noFreeHead
|
||||
dt.freeNum = 0
|
||||
|
||||
return errors.Join(errs...)
|
||||
if dt.bufferBase != 0 {
|
||||
// The pointer points to memory not managed by Go, so this conversion
|
||||
// is safe. See https://github.com/golang/go/issues/58625
|
||||
dt.bufferBase = 0
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
err := unix.MunmapPtr(unsafe.Pointer(dt.bufferBase), uintptr(dt.bufferSize))
|
||||
if err != nil {
|
||||
return fmt.Errorf("release buffer memory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDescriptorChain creates a new descriptor chain within the descriptor
|
||||
|
||||
Reference in New Issue
Block a user