Mirror of https://github.com/slackhq/nebula.git, synced 2025-11-22 08:24:25 +01:00

Commit message: yeah

This commit touches the virtqueue package (available ring, descriptor table, split queue, used ring) and the tun write path. In summary: it logs transmit errors in tun.Write, removes the per-ring mutexes in favour of the caller's single offerMutex, makes the per-descriptor buffer size (itemSize) a constructor parameter instead of re-reading the page size, and replaces errors.Is with a direct sentinel comparison on the offer path.
@@ -736,6 +736,7 @@ func (t *tun) Write(b []byte) (int, error) {
 	err := t.vdev.TransmitPacket(hdr, b)
 	if err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
 		return 0, err
 	}
 	return maximum, nil
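This first hunk lives in the tun device's Write path (the file name is not preserved in this view) and adds a log line before the early return, so transmit failures are no longer silent. The call shape matches logrus, which nebula uses for logging; a minimal, self-contained illustration of the pattern, with the error value invented for the demo:

package main

import (
	"errors"

	"github.com/sirupsen/logrus"
)

func main() {
	l := logrus.New()
	err := errors.New("transmit queue full") // stand-in error for the demo

	// Same shape as the added line: attach the error as a structured
	// field and log at error level; the real caller still returns err.
	l.WithError(err).Error("Transmitting packet")
}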
@@ -2,7 +2,6 @@ package virtqueue
 
 import (
 	"fmt"
-	"sync"
 	"unsafe"
 )
@@ -49,7 +48,7 @@ type AvailableRing struct {
 	// virtio specification.
 	usedEvent *uint16
 
-	mu sync.Mutex
+	//mu sync.Mutex
 }
 
 // newAvailableRing creates an available ring that uses the given underlying
@@ -84,8 +83,9 @@ func (r *AvailableRing) Address() uintptr {
 // advances the ring index accordingly to make the device process the new
 // descriptor chains.
 func (r *AvailableRing) offer(chainHeads []uint16) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
+	// always called under lock
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
 
 	// Add descriptor chain heads to the ring.
 	for offset, head := range chainHeads {
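The available-ring hunks above comment out the ring's own mutex and note that offer is only ever reached with an external lock held; that lock is the SplitQueue.offerMutex seen later in this commit. A standalone sketch of the ownership pattern, with simplified stand-in types:

package main

import (
	"fmt"
	"sync"
)

// ring stands in for AvailableRing; it no longer locks on its own.
type ring struct {
	heads []uint16 // stand-in for the shared ring memory
}

// offer must only be called while the owner's mutex is held.
func (r *ring) offer(chainHeads []uint16) {
	r.heads = append(r.heads, chainHeads...)
}

// queue stands in for SplitQueue, which serializes all offers.
type queue struct {
	offerMutex sync.Mutex
	ring       ring
}

func (q *queue) OfferDescriptorChain(heads []uint16) {
	q.offerMutex.Lock()
	defer q.offerMutex.Unlock()
	q.ring.offer(heads) // safe: serialized by offerMutex
}

func main() {
	q := &queue{}
	q.OfferDescriptorChain([]uint16{0, 1})
	fmt.Println(q.ring.heads)
}

The gain is one lock acquisition per offer instead of two; the risk is that the "always called under lock" invariant is now enforced only by convention.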
@@ -4,7 +4,6 @@ import (
 	"errors"
 	"fmt"
 	"math"
-	"os"
 	"sync"
 	"unsafe"
@@ -54,6 +53,7 @@ type DescriptorTable struct {
 
 	bufferBase uintptr
 	bufferSize int
+	itemSize   int
 
 	mu sync.Mutex
 }
@@ -63,7 +63,7 @@ type DescriptorTable struct {
 // descriptor table (see [descriptorTableSize]) for the given queue size.
 //
 // Before this descriptor table can be used, [initialize] must be called.
-func newDescriptorTable(queueSize int, mem []byte) *DescriptorTable {
+func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
 	dtSize := descriptorTableSize(queueSize)
 	if len(mem) != dtSize {
 		panic(fmt.Sprintf("memory size (%v) does not match required size "+
@@ -75,6 +75,7 @@ func newDescriptorTable(queueSize int, mem []byte) *DescriptorTable {
 		// We have no free descriptors until they were initialized.
 		freeHeadIndex: noFreeHead,
 		freeNum:       0,
+		itemSize:      itemSize, // TODO: configurable? needs to be page-aligned
 	}
 }
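These descriptor-table hunks thread a fixed itemSize through construction instead of calling os.Getpagesize() during initialization. A self-contained sketch of the resulting size math, assuming the default of one page per descriptor:

package main

import (
	"fmt"
	"os"
)

func main() {
	// One contiguous region backs all descriptor buffers; its size is
	// now derived from the itemSize chosen at construction time.
	queueSize := 256             // a typical virtqueue size
	itemSize := os.Getpagesize() // 4096 on most Linux systems
	totalSize := itemSize * queueSize
	fmt.Printf("itemSize=%d B, region=%d B (%d KiB)\n",
		itemSize, totalSize, totalSize/1024)
}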
@@ -116,10 +117,8 @@ func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
 func (dt *DescriptorTable) initializeDescriptors() error {
 	numDescriptors := len(dt.descriptors)
 
-	itemSize := os.Getpagesize() //todo configurable? needs to be page-aligned
-
 	// Allocate ONE large region for all buffers
-	totalSize := itemSize * numDescriptors
+	totalSize := dt.itemSize * numDescriptors
 	basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
 		unix.PROT_READ|unix.PROT_WRITE,
 		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
@@ -136,7 +135,7 @@ func (dt *DescriptorTable) initializeDescriptors() error {
 
 	for i := range dt.descriptors {
 		dt.descriptors[i] = Descriptor{
-			address: dt.bufferBase + uintptr(i*itemSize),
+			address: dt.bufferBase + uintptr(i*dt.itemSize),
 			length:  0,
 			// All descriptors should form a free chain that loops around.
 			flags: descriptorFlagHasNext,
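With the base address and the item size both fixed, each descriptor's buffer address is plain arithmetic, so neighbouring buffers can never overlap. A standalone illustration of the layout (the base address is illustrative only):

package main

import "fmt"

func main() {
	var bufferBase uintptr = 0x10000000 // illustrative, not a real mapping
	itemSize := 4096
	// Descriptor i owns [bufferBase+i*itemSize, bufferBase+(i+1)*itemSize).
	for i := 0; i < 4; i++ {
		addr := bufferBase + uintptr(i*itemSize)
		fmt.Printf("descriptor %d -> %#x\n", i, addr)
	}
}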
@@ -202,8 +201,6 @@ func (dt *DescriptorTable) releaseBuffers() error {
 // caller should try again after some descriptor chains were used by the device
 // and returned back into the free chain.
 func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
-	pageSize := os.Getpagesize()
-
 	// Calculate the number of descriptors needed to build the chain.
 	numDesc := uint16(len(outBuffers) + numInBuffers)
@@ -217,7 +214,7 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
 
 	// Do we still have enough free descriptors?
 	if numDesc > dt.freeNum {
-		return 0, fmt.Errorf("%w: %d free but needed %d", ErrNotEnoughFreeDescriptors, dt.freeNum, numDesc)
+		return 0, ErrNotEnoughFreeDescriptors
 	}
 
 	// Above validation ensured that there is at least one free descriptor, so
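Dropping the fmt.Errorf wrapper loses the free/needed counts from the error message, but it is exactly what makes the direct == comparison in OfferDescriptorChain (later in this commit) valid: a wrapped error only matches via errors.Is. A self-contained demonstration:

package main

import (
	"errors"
	"fmt"
)

var ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors")

func main() {
	// Old style: wrapping carries diagnostics but breaks ==.
	wrapped := fmt.Errorf("%w: 3 free but needed 5", ErrNotEnoughFreeDescriptors)
	fmt.Println(wrapped == ErrNotEnoughFreeDescriptors)          // false
	fmt.Println(errors.Is(wrapped, ErrNotEnoughFreeDescriptors)) // true

	// New style: the bare sentinel compares directly.
	bare := ErrNotEnoughFreeDescriptors
	fmt.Println(bare == ErrNotEnoughFreeDescriptors) // true
}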
@@ -238,16 +235,16 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
 		desc := &dt.descriptors[next]
 		checkUnusedDescriptorLength(next, desc)
 
-		if len(buffer) > pageSize {
+		if len(buffer) > dt.itemSize {
 			// The caller should already prevent that from happening.
-			panic(fmt.Sprintf("out buffer %d has size %d which exceeds page size %d", i, len(buffer), pageSize))
+			panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
 		}
 
 		// Copy the buffer to the memory referenced by the descriptor.
 		// The descriptor address points to memory not managed by Go, so this
 		// conversion is safe. See https://github.com/golang/go/issues/58625
 		//goland:noinspection GoVetUnsafePointer
-		copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), pageSize), buffer)
+		copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
 		desc.length = uint32(len(buffer))
 
 		// Clear the flags in case there were any others set.
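The copy into descriptor memory goes through unsafe.Slice because the destination bytes live outside the Go heap; the hunk only swaps the hard-coded page size for dt.itemSize. A runnable model of the same shape; here a Go slice stands in for the mmap'd region, which is why keeping backing alive matters (the real address points at non-Go memory, the case golang.org/issue/58625 discusses):

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	backing := make([]byte, 4096) // stand-in for the mmap'd region
	addr := uintptr(unsafe.Pointer(&backing[0]))
	itemSize := len(backing)

	payload := []byte("hello virtqueue")

	// Same shape as the hunk: view raw memory as a byte slice, copy at
	// most itemSize bytes, then record the actual payload length.
	//goland:noinspection GoVetUnsafePointer
	dst := unsafe.Slice((*byte)(unsafe.Pointer(addr)), itemSize)
	n := copy(dst, payload)

	fmt.Println(n, string(backing[:n])) // also keeps backing live
}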
@@ -261,7 +258,7 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
 		checkUnusedDescriptorLength(next, desc)
 
 		// Give the device the maximum available number of bytes to write into.
-		desc.length = uint32(pageSize)
+		desc.length = uint32(dt.itemSize)
 
 		// Mark the descriptor as device-writable.
 		desc.flags = descriptorFlagHasNext | descriptorFlagWritable
@@ -48,6 +48,7 @@ type SplitQueue struct {
 	// [SplitQueue.OfferDescriptorChain].
 	offerMutex sync.Mutex
 
 	pageSize int
+	itemSize int
 }
 
 // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
@@ -61,6 +62,7 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
 	sq := SplitQueue{
 		size:     queueSize,
 		pageSize: os.Getpagesize(),
+		itemSize: os.Getpagesize(), // TODO: config
 	}
 
 	// Clean up a partially initialized queue when something fails.
@@ -109,7 +111,7 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
 		return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
 	}
 
-	sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd])
+	sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
 	sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
 	sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
@@ -241,6 +243,12 @@ func (sq *SplitQueue) consumeUsedRing(ctx context.Context) error {
 	return nil
 }
 
+// blockForMoreDescriptors blocks on a channel waiting for more descriptors to free up.
+// It is its own function so that it can show up in pprof.
+func (sq *SplitQueue) blockForMoreDescriptors() {
+	<-sq.moreFreeDescriptors
+}
+
 // OfferDescriptorChain offers a descriptor chain to the device which contains a
 // number of device-readable buffers (out buffers) and device-writable buffers
 // (in buffers).
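blockForMoreDescriptors only receives from the channel; the sending side is not shown in this diff. A plausible counterpart, sketched purely as an assumption (a non-blocking send is a common choice so the path that frees descriptors never stalls when nobody is waiting):

package main

import "fmt"

// signalMoreFree is hypothetical; the real signalling code is outside
// this diff.
func signalMoreFree(moreFreeDescriptors chan struct{}) {
	select {
	case moreFreeDescriptors <- struct{}{}:
	default: // no waiter; dropping the signal is fine
	}
}

func main() {
	ch := make(chan struct{}, 1)
	signalMoreFree(ch)
	<-ch // a blocked offerer would wake up here
	fmt.Println("woken")
}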
@@ -292,12 +300,19 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
 		if err == nil {
 			break
 		}
-		if waitFree && errors.Is(err, ErrNotEnoughFreeDescriptors) {
-			// Wait for more free descriptors to be put back into the queue.
-			// If the number of free descriptors is still not sufficient, we'll
-			// land here again.
-			<-sq.moreFreeDescriptors
-			continue
+		// Don't use errors.Is here; it is too slow for this path.
+		//goland:noinspection GoDirectComparisonOfErrors
+		if err == ErrNotEnoughFreeDescriptors {
+			if waitFree {
+				// Wait for more free descriptors to be put back into the queue.
+				// If the number of free descriptors is still not sufficient, we'll
+				// land here again.
+				sq.blockForMoreDescriptors()
+				continue
+			} else {
+				return 0, err
+			}
 		}
 		return 0, fmt.Errorf("create descriptor chain: %w", err)
 	}
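The new control flow compares directly against the sentinel and handles waitFree explicitly, so callers that opted out of waiting now get ErrNotEnoughFreeDescriptors back unwrapped. A standalone model of the retry loop, with try and wait as stand-ins for createDescriptorChain and blockForMoreDescriptors:

package main

import (
	"errors"
	"fmt"
)

var ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors")

func offer(try func() error, wait func(), waitFree bool) error {
	for {
		err := try()
		if err == nil {
			return nil
		}
		// Direct comparison works because try returns the bare sentinel.
		if err == ErrNotEnoughFreeDescriptors {
			if waitFree {
				wait() // blocks until descriptors are freed
				continue
			}
			return err
		}
		return fmt.Errorf("create descriptor chain: %w", err)
	}
}

func main() {
	calls := 0
	err := offer(func() error {
		calls++
		if calls < 3 {
			return ErrNotEnoughFreeDescriptors
		}
		return nil
	}, func() {}, true)
	fmt.Println(calls, err) // 3 <nil>
}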
@@ -340,6 +355,7 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
 func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
 	sq.ensureInitialized()
 
+	// not called under lock
 	if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
 		return fmt.Errorf("free: %w", err)
 	}
@@ -2,7 +2,6 @@ package virtqueue
 
 import (
 	"fmt"
-	"sync"
 	"unsafe"
 )
@@ -52,7 +51,7 @@ type UsedRing struct {
 	// processed.
 	lastIndex uint16
 
-	mu sync.Mutex
+	//mu sync.Mutex
 }
 
 // newUsedRing creates a used ring that uses the given underlying memory. The
@@ -87,9 +86,10 @@ func (r *UsedRing) Address() uintptr {
 
 // take returns all new [UsedElement]s that the device put into the ring and
 // that weren't already returned by a previous call to this method.
+// This method previously took a lock; it has been removed.
 func (r *UsedRing) take() []UsedElement {
-	r.mu.Lock()
-	defer r.mu.Unlock()
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
 
 	ringIndex := *r.ringIndex
 	if ringIndex == r.lastIndex {
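Removing the UsedRing mutex is sound only if a single goroutine calls take; the commit's comments assert this but nothing enforces it. A standalone model of the single-consumer bookkeeping, with an atomic standing in for the device-written ring index (the real code reads a plain uint16 in shared memory):

package main

import (
	"fmt"
	"sync/atomic"
)

type usedRing struct {
	ringIndex atomic.Uint32 // stand-in for the index the device advances
	lastIndex uint32        // touched only by the single consumer
}

// take needs no lock: lastIndex has exactly one reader and writer.
func (r *usedRing) take() int {
	idx := r.ringIndex.Load()
	n := int(idx - r.lastIndex) // elements added since the last call
	r.lastIndex = idx
	return n
}

func main() {
	r := &usedRing{}
	r.ringIndex.Store(3)  // the device "publishes" three used elements
	fmt.Println(r.take()) // 3
	fmt.Println(r.take()) // 0
}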