This commit is contained in:
JackDoan
2025-11-11 17:00:40 -06:00
parent cd30e5aa01
commit 400fdace9d
4 changed files with 19 additions and 35 deletions

View File

@@ -298,7 +298,7 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) { func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need: //read first element to see how many descriptors we need:
pkt.Payload = pkt.Payload[:cap(pkt.Payload)] pkt.Payload = pkt.Payload[:cap(pkt.Payload)]
n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload) n, err := dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[0].DescriptorIndex), pkt.Payload, int(chains[0].Length)) //todo
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -333,7 +333,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
i := 1 i := 1
// we used chain 0 already // we used chain 0 already
for i = 1; i < len(chains); i++ { for i = 1; i < len(chains); i++ {
n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:]) n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
if err != nil { if err != nil {
// When this fails we may miss to free some descriptor chains. We // When this fails we may miss to free some descriptor chains. We
// could try to mitigate this by deferring the freeing somehow, but // could try to mitigate this by deferring the freeing somehow, but

View File

@@ -349,7 +349,7 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
return return
} }
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (int, error) { func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
if int(head) > len(dt.descriptors) { if int(head) > len(dt.descriptors) {
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
} }
@@ -387,7 +387,9 @@ func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (
next = desc.next next = desc.next
} }
if maxLen > 0 {
//todo length = min(maxLen, length)
}
//set out to length: //set out to length:
out = out[:length] out = out[:length]
@@ -399,7 +401,7 @@ func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte) (
// The descriptor address points to memory not managed by Go, so this // The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625 // conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer //goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length) bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
copied += copy(out[copied:], bs) copied += copy(out[copied:], bs)
// Is this the tail of the chain? // Is this the tail of the chain?

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"sync" "sync"
"syscall"
"github.com/slackhq/nebula/overlay/eventfd" "github.com/slackhq/nebula/overlay/eventfd"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -30,16 +31,6 @@ type SplitQueue struct {
// chains and put them in the used ring. // chains and put them in the used ring.
callEventFD eventfd.EventFD callEventFD eventfd.EventFD
// UsedChains is a chanel that receives [UsedElement]s for descriptor chains
// that were used by the device.
UsedChains chan UsedElement
// moreFreeDescriptors is a channel that signals when any descriptors were
// put back into the free chain of the descriptor table. This is used to
// unblock methods waiting for available room in the queue to create new
// descriptor chains again.
moreFreeDescriptors chan struct{}
// stop is used by [SplitQueue.Close] to cancel the goroutine that handles // stop is used by [SplitQueue.Close] to cancel the goroutine that handles
// used buffer notifications. It blocks until the goroutine ended. // used buffer notifications. It blocks until the goroutine ended.
stop func() error stop func() error
@@ -131,10 +122,6 @@ func NewSplitQueue(queueSize int) (_ *SplitQueue, err error) {
return nil, fmt.Errorf("initialize descriptors: %w", err) return nil, fmt.Errorf("initialize descriptors: %w", err)
} }
// Initialize channels.
sq.UsedChains = make(chan UsedElement, queueSize)
sq.moreFreeDescriptors = make(chan struct{})
sq.epoll, err = eventfd.NewEpoll() sq.epoll, err = eventfd.NewEpoll()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -366,7 +353,8 @@ func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]by
// Wait for more free descriptors to be put back into the queue. // Wait for more free descriptors to be put back into the queue.
// If the number of free descriptors is still not sufficient, we'll // If the number of free descriptors is still not sufficient, we'll
// land here again. // land here again.
<-sq.moreFreeDescriptors //todo should never happen
syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier
continue continue
} }
return nil, fmt.Errorf("create descriptor chain: %w", err) return nil, fmt.Errorf("create descriptor chain: %w", err)
@@ -400,9 +388,9 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
return sq.descriptorTable.getDescriptorChain(head) return sq.descriptorTable.getDescriptorChain(head)
} }
func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte) (int, error) { func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
sq.ensureInitialized() sq.ensureInitialized()
return sq.descriptorTable.getDescriptorChainContents(head, out) return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
} }
// FreeDescriptorChain frees the descriptor chain with the given head index. // FreeDescriptorChain frees the descriptor chain with the given head index.
@@ -415,20 +403,11 @@ func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte) (int,
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that // When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
// are waiting for free room in the queue, they may become unblocked by this. // are waiting for free room in the queue, they may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error { func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
sq.ensureInitialized()
//not called under lock //not called under lock
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil { if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
return fmt.Errorf("free: %w", err) return fmt.Errorf("free: %w", err)
} }
// There is more free room in the descriptor table now.
// This is a fire-and-forget signal, so do not block when nobody listens.
select { //todo eliminate
case sq.moreFreeDescriptors <- struct{}{}:
default:
}
return nil return nil
} }
@@ -473,10 +452,6 @@ func (sq *SplitQueue) Close() error {
errs = append(errs, fmt.Errorf("stop consume used ring: %w", err)) errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
} }
// The stop function blocked until the goroutine ended, so the channel
// can now safely be closed.
close(sq.UsedChains)
// Make sure that this code block is executed only once. // Make sure that this code block is executed only once.
sq.stop = nil sq.stop = nil
} }

View File

@@ -184,6 +184,13 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip) return u.writeTo6(b, ip)
} }
func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
if u.isV4 {
return u.writeTo4(b, ip)
}
return u.writeTo6(b, ip)
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6 rsa.Family = unix.AF_INET6