refactoring a bit

JackDoan
2025-12-18 13:27:28 -06:00
parent f5c46c43ce
commit 41c9a3b2eb
19 changed files with 229 additions and 387 deletions

View File

@@ -13,7 +13,7 @@ import (
"github.com/slackhq/nebula/routing"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.UDPPacket, q int, localCache firewall.ConntrackCache, now time.Time) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
@@ -412,7 +412,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.UDPPacket, q int) {
if ci.eKey == nil {
return
}

View File

@@ -294,7 +294,7 @@ func (f *Interface) listenOut(q int) {
toSend := make([][]byte, batch)
li.ListenOut(func(pkts []*packet.Packet) {
li.ListenOut(func(pkts []*packet.UDPPacket) {
toSend = toSend[:0]
for i := range outPackets {
outPackets[i].SegCounter = 0
@@ -323,11 +323,11 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
packets := make([]*packet.VirtIOPacket, batch)
outPackets := make([]*packet.Packet, batch)
packets := reader.NewPacketArrays(batch)
outPackets := make([]*packet.UDPPacket, batch)
for i := 0; i < batch; i++ {
packets[i] = packet.NewVIO()
outPackets[i] = packet.New(false) //todo?
outPackets[i] = packet.New(false) //todo isv4?
}
for {
@@ -352,9 +352,8 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
now := time.Now()
for i, pkt := range packets[:n] {
outPackets[i].ReadyToSend = false
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
f.consumeInsidePacket(pkt.GetPayload(), fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
pkt.Reset()
}
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
if err != nil {

View File

@@ -359,7 +359,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
f.connectionManager.In(hostinfo)
}
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0]
via := ViaSender{UdpAddr: pkt.AddrPort()}

overlay/packets.go (new file, +36)
View File

@@ -0,0 +1,36 @@
package overlay
//import (
// "github.com/slackhq/nebula/util/virtio"
//)
//type VirtIOPacket struct {
// Payload []byte
// Header virtio.NetHdr
// Chains []uint16
// ChainRefs [][]byte
//}
//
//func NewVIO() *VirtIOPacket {
// out := new(VirtIOPacket)
// out.Payload = nil
// out.ChainRefs = make([][]byte, 0, 4)
// out.Chains = make([]uint16, 0, 8)
// return out
//}
//
//func (v *VirtIOPacket) Reset() {
// v.Payload = nil
// v.ChainRefs = v.ChainRefs[:0]
// v.Chains = v.Chains[:0]
//}
// TunPacket was formerly VirtIOPacket
type TunPacket interface {
SetPayload([]byte)
GetPayload() []byte
}
type OutPacket interface {
SetPayload([]byte)
GetPayload() []byte
}
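Since both interfaces only expose payload access, any buffer-backed type satisfies them. A minimal sketch (bufPacket is a hypothetical name, not part of this commit):

// bufPacket is a hypothetical buffer-backed packet that satisfies both
// TunPacket and OutPacket for devices that don't own descriptor memory.
type bufPacket struct {
	payload []byte
}

func (b *bufPacket) SetPayload(p []byte) { b.payload = p }
func (b *bufPacket) GetPayload() []byte  { return b.payload }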

View File

@@ -16,13 +16,15 @@ const DefaultMTU = 1300
type TunDev interface {
io.WriteCloser
ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
NewPacketArrays(batchSize int) []TunPacket
ReadMany(x []TunPacket, q int) (int, error)
RecycleRxSeg(pkt TunPacket, kick bool, q int) error
//todo this interface sux
AllocSeg(pkt *packet.OutPacket, q int) (int, error)
WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
WriteMany(x []*packet.OutPacket, q int) (int, error)
RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
}
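Taken together, the new methods define the read-side contract: allocate packets once with NewPacketArrays, fill them with ReadMany, and hand each one back with RecycleRxSeg. A rough sketch of a caller, mirroring listenIn above (drainTun and handle are hypothetical names; error handling is trimmed):

func drainTun(dev TunDev, q, batch int, handle func([]byte)) error {
	pkts := dev.NewPacketArrays(batch)
	for {
		n, err := dev.ReadMany(pkts, q)
		if err != nil {
			return err
		}
		for i, pkt := range pkts[:n] {
			handle(pkt.GetPayload())
			// Return the buffer to the device; kick only on the last packet.
			if err := dev.RecycleRxSeg(pkt, i == n-1, q); err != nil {
				return err
			}
		}
	}
}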
// TODO: We may be able to remove routines
@@ -31,8 +33,8 @@ type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefi
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
t := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return t, nil
default:
return newTun(c, l, vpnNetworks, routines > 1)

View File

@@ -24,7 +24,11 @@ type disabledTun struct {
l *logrus.Logger
}
func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
func (t *disabledTun) NewPacketArrays(batchSize int) []TunPacket {
panic("implement me") //TODO
}
func (*disabledTun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
return nil
}
@@ -131,8 +135,8 @@ func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
}
func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return t.Read(b[0].Payload)
func (t *disabledTun) ReadMany(b []TunPacket, _ int) (int, error) {
return t.Read(b[0].GetPayload())
}
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {

View File

@@ -4,6 +4,7 @@
package overlay
import (
"context"
"fmt"
"net"
"net/netip"
@@ -183,6 +184,14 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
return t, nil
}
func (t *tun) NewPacketArrays(batchSize int) []TunPacket {
inPackets := make([]TunPacket, batchSize)
for i := 0; i < batchSize; i++ {
inPackets[i] = vhostnet.NewVIO()
}
return inPackets
}
func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
@@ -725,12 +734,25 @@ func (t *tun) Close() error {
return nil
}
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
n, err := t.vdev[q].ReceivePackets(p) //we are TXing
func (t *tun) ReadMany(p []TunPacket, q int) (int, error) {
err := t.vdev[q].ReceiveQueue.WaitForUsedElements(context.TODO())
if err != nil {
return 0, err
}
return n, nil
i := 0
for i < len(p) {
item, ok := t.vdev[q].ReceiveQueue.TakeSingleNoBlock()
if !ok {
break
}
pkt := p[i].(*vhostnet.VirtIOPacket) //todo I'm not happy about this but I don't want to change how memory is "owned" rn
_, err = t.vdev[q].ProcessRxChain(pkt, item)
if err != nil {
return i, err
}
i++
}
return i, nil
}
func (t *tun) Write(b []byte) (int, error) {
@@ -783,6 +805,9 @@ func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return maximum, nil
}
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
vpkt := pkt.(*vhostnet.VirtIOPacket)
err := t.vdev[q].ReceiveQueue.OfferDescriptorChains(vpkt.Chains, kick)
vpkt.Reset() // reset unconditionally; the OfferDescriptorChains error is still returned below
return err
}

View File

@@ -106,7 +106,7 @@ func (t *TestTun) Name() string {
return t.Device
}
func (t *TestTun) ReadMany(x []*packet.VirtIOPacket, q int) (int, error) {
func (t *TestTun) ReadMany(x []TunPacket, q int) (int, error) {
p, ok := <-t.rxPackets
if !ok {
return 0, os.ErrClosed
@@ -165,7 +165,7 @@ func (t *TestTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return len(x), nil
}
func (t *TestTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
func (t *TestTun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
//todo this ought to maybe track something
return nil
}

View File

@@ -38,7 +38,18 @@ type UserDevice struct {
inboundWriter *io.PipeWriter
}
func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
func (d *UserDevice) NewPacketArrays(batchSize int) []TunPacket {
//inPackets := make([]TunPacket, batchSize)
//outPackets := make([]OutPacket, batchSize)
panic("not implemented") //todo!
//for i := 0; i < batchSize; i++ {
// inPackets[i] = vhostnet.NewVIO()
// outPackets[i] = packet.New(false)
//}
//return inPackets, outPackets
}
func (d *UserDevice) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
return nil
}
@@ -76,8 +87,12 @@ func (d *UserDevice) Close() error {
return nil
}
func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return d.Read(b[0].Payload)
func (d *UserDevice) ReadMany(b []TunPacket, _ int) (int, error) {
_, err := d.Read(b[0].GetPayload())
if err != nil {
return 0, err
}
return 1, nil
}
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {

View File

@@ -118,7 +118,7 @@ func NewDevice(options ...Option) (*Device, error) {
return nil, fmt.Errorf("set transmit queue backend: %w", err)
}
// Fully populate the receive queue with available buffers which the device
// Fully populate the rx queue with available buffers which the device
// can write new packets into.
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
@@ -198,11 +198,8 @@ func (dev *Device) Close() error {
// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
var (
queue *virtqueue.SplitQueue
err error
)
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
queue, err := virtqueue.NewSplitQueue(queueSize, itemSize)
if err != nil {
return nil, fmt.Errorf("create virtqueue: %w", err)
}
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
@@ -218,10 +215,10 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
if err == virtqueue.ErrNotEnoughFreeDescriptors {
dev.fullTable = true
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
}
} else {
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
}
if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err)
@@ -271,18 +268,15 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
return nil
}
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
// ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned.
func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need:
pkt.Reset()
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
idx := uint16(chain.DescriptorIndex)
buf, err := dev.ReceiveQueue.GetDescriptorItem(idx)
if err != nil {
return 0, fmt.Errorf("get descriptor chain: %w", err)
}
if len(pkt.ChainRefs) == 0 {
return 1, nil
}
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
@@ -290,7 +284,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
if err = pkt.header.Decode(buf); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
@@ -298,72 +292,44 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
}
//we have the header now: what do we need to do?
if int(pkt.Header.NumBuffers) > len(chains) {
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
if int(pkt.header.NumBuffers) > 1 {
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1)
}
if int(pkt.Header.NumBuffers) != 1 {
return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
if int(pkt.header.NumBuffers) != 1 {
return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers))
}
if chains[0].Length > 16000 {
if chain.Length > 16000 {
//todo!
return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
return 1, fmt.Errorf("too big packet length: %d", chain.Length)
}
//shift the buffer out of out:
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
pkt.payload = buf[virtio.NetHdrSize:chain.Length]
pkt.Chains = append(pkt.Chains, idx)
return 1, nil
//cursor := n - virtio.NetHdrSize
//
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
// return 1, nil
//}
//
//i := 1
//// we used chain 0 already
//for i = 1; i < len(chains); i++ {
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
// if err != nil {
// // When this fails we may miss to free some descriptor chains. We
// // could try to mitigate this by deferring the freeing somehow, but
// // it's not worth the hassle. When this method fails, the queue will
// // be in a broken state anyway.
// return i, fmt.Errorf("get descriptor chain: %w", err)
// }
// cursor += n
//}
////todo this has to be wrong
//pkt.Payload = pkt.Payload[:cursor]
//return i, nil
}
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
//todo optimize?
var chains []virtqueue.UsedElement
var err error
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
if err != nil {
return 0, err
}
if len(chains) == 0 {
return 0, nil
type VirtIOPacket struct {
payload []byte
header virtio.NetHdr
Chains []uint16
}
numPackets := 0
chainsIdx := 0
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
if numPackets >= len(out) {
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
}
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
if err != nil {
return 0, err
}
chainsIdx += numChains
func NewVIO() *VirtIOPacket {
out := new(VirtIOPacket)
out.payload = nil
out.Chains = make([]uint16, 0, 8)
return out
}
return numPackets, nil
func (v *VirtIOPacket) Reset() {
v.payload = nil
v.Chains = v.Chains[:0]
}
func (v *VirtIOPacket) GetPayload() []byte {
return v.payload
}
func (v *VirtIOPacket) SetPayload(x []byte) {
v.payload = x //todo?
}
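With payload and header now unexported, code outside vhostnet reaches the packet only through the TunPacket accessors, while Chains keeps the descriptor indices that the owner must re-offer. A sketch of one packet's lifecycle under these assumptions (receiveOne and use are hypothetical names; error handling is trimmed):

func receiveOne(dev *Device, elem virtqueue.UsedElement, use func([]byte)) error {
	pkt := NewVIO()
	if _, err := dev.ProcessRxChain(pkt, elem); err != nil { // fills payload, header, Chains
		return err
	}
	use(pkt.GetPayload())
	// Hand the descriptors back and clear the packet, as tun.RecycleRxSeg does.
	err := dev.ReceiveQueue.OfferDescriptorChains(pkt.Chains, true)
	pkt.Reset()
	return err
}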

View File

@@ -10,10 +10,6 @@ import (
)
var (
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
// no buffers, which is not allowed.
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
// exhausted, meaning that the queue is full.
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
@@ -272,59 +268,6 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
return head, nil
}
// TODO: Implement a zero-copy variant of createDescriptorChain?
// getDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain that starts with
// the given head index. The descriptor chain must have been created using
// [createDescriptorChain] and must not have been freed yet (meaning that the
// head index must not be contained in the free chain).
//
// Be careful to only access the returned buffer slices when the device has not
// yet or is no longer using them. They must not be accessed after
// [freeDescriptorChain] has been called.
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
if int(head) > len(dt.descriptors) {
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
outBuffers = append(outBuffers, bs)
} else {
inBuffers = append(inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return
}
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
if int(head) > len(dt.descriptors) {
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
@@ -339,121 +282,6 @@ func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
return bs, nil
}
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
if int(head) > len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
return fmt.Errorf("there should not be an outbuffer in %d", head)
} else {
*inBuffers = append(*inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return nil
}
// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
// of [createDescriptorChain].
// The descriptor chain must have been created using [createDescriptorChain] and
// must not have been freed yet (meaning that the head index must not be
// contained in the free chain).
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
if int(head) > len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
var tailDesc *Descriptor
var chainLen uint16
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
chainLen++
// Set the length of all unused descriptors back to zero.
desc.length = 0
// Unset all flags except the next flag.
desc.flags &= descriptorFlagHasNext
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
tailDesc = desc
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if tailDesc == nil {
// A descriptor chain longer than the queue size but without loops
// should be impossible.
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
}
// The tail descriptor does not have the next flag set, but when it comes
// back into the free chain, it should have.
tailDesc.flags = descriptorFlagHasNext
if dt.freeHeadIndex == noFreeHead {
// The whole free chain was used up, so we turn this returned descriptor
// chain into the new free chain by completing the circle and using its
// head.
tailDesc.next = head
dt.freeHeadIndex = head
} else {
// Attach the returned chain at the beginning of the free chain but
// right after the free chain head.
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
tailDesc.next = freeHeadDesc.next
freeHeadDesc.next = head
}
dt.freeNum += chainLen
return nil
}
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
// is zero, as it should be.
// This is not a requirement by the virtio spec but rather a thing we do to

View File

@@ -128,8 +128,7 @@ func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
return nil, err
}
// Consume used buffer notifications in the background.
sq.stop = sq.startConsumeUsedRing()
sq.stop = sq.kickSelfToExit()
return &sq, nil
}
@@ -169,9 +168,7 @@ func (sq *SplitQueue) CallEventFD() int {
return sq.callEventFD.FD()
}
// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
// A function is returned that can be used to gracefully cancel it. todo rename
func (sq *SplitQueue) startConsumeUsedRing() func() error {
func (sq *SplitQueue) kickSelfToExit() func() error {
return func() error {
// The goroutine blocks until it receives a signal on the event file
@@ -185,7 +182,15 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
}
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
func (sq *SplitQueue) TakeSingleIndex(ctx context.Context) (uint16, error) {
element, err := sq.TakeSingle(ctx)
if err != nil {
return 0xffff, err
}
return element.GetHead(), nil
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
@@ -195,7 +200,7 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
}
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return 0, fmt.Errorf("wait: %w", err)
return UsedElement{}, fmt.Errorf("wait: %w", err)
}
if n > 0 {
@@ -208,7 +213,31 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
}
}
}
return 0, ctx.Err()
return UsedElement{}, ctx.Err()
}
func (sq *SplitQueue) TakeSingleNoBlock() (UsedElement, bool) {
return sq.usedRing.takeOne()
}
func (sq *SplitQueue) WaitForUsedElements(ctx context.Context) error {
if sq.usedRing.availableToTake() != 0 {
return nil
}
for ctx.Err() == nil {
// Wait for a signal from the device.
n, err := sq.epoll.Block()
if err != nil {
return fmt.Errorf("wait: %w", err)
}
if n > 0 {
_ = sq.epoll.Clear()
if sq.usedRing.availableToTake() != 0 {
return nil
}
}
}
return ctx.Err()
}
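WaitForUsedElements returns immediately when the used ring already has entries and otherwise blocks on the call event FD, so paired with TakeSingleNoBlock it gives callers a wait-then-drain pattern without the slice allocation of BlockAndGetHeadsCapped. A sketch (drain and consume are hypothetical names):

func drain(ctx context.Context, sq *SplitQueue, consume func(UsedElement)) error {
	if err := sq.WaitForUsedElements(ctx); err != nil {
		return err
	}
	for {
		elem, ok := sq.TakeSingleNoBlock()
		if !ok {
			return nil // ring is empty again
		}
		consume(elem)
	}
}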
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
@@ -235,7 +264,7 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
_ = sq.epoll.Clear() //???
_ = sq.epoll.Clear()
stillNeedToTake, out = sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
@@ -296,16 +325,14 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
sq.availableRing.offerSingle(head)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
if err = sq.kickEventFD.Kick(); err != nil {
return head, fmt.Errorf("notify device: %w", err)
}
return head, nil
}
// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.
// GetDescriptorItem returns the buffer for a given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
@@ -313,37 +340,11 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
return sq.descriptorTable.getDescriptorChain(head)
}
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
return sq.descriptorTable.getDescriptorItem(head)
}
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
}
// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// This creates new room in the queue which can be used by following
// [SplitQueue.OfferDescriptorChain] calls.
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
// are waiting for free room in the queue, they may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
//not called under lock
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
return fmt.Errorf("free: %w", err)
}
return nil
}
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
//not called under lock
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)

View File

@@ -84,17 +84,11 @@ func (r *UsedRing) Address() uintptr {
return uintptr(unsafe.Pointer(r.flags))
}
// take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method.
// had a lock, I removed it
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
//r.mu.Lock()
//defer r.mu.Unlock()
func (r *UsedRing) availableToTake() int {
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0, nil
return 0
}
// Calculate the number of new used elements that we can read from the ring.
@@ -103,6 +97,16 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
if count < 0 {
count += 0xffff
}
return count
}
// take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method.
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
count := r.availableToTake()
if count == 0 {
return 0, nil
}
stillNeedToTake := 0
@@ -128,21 +132,13 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
return stillNeedToTake, elems
}
func (r *UsedRing) takeOne() (uint16, bool) {
func (r *UsedRing) takeOne() (UsedElement, bool) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0xffff, false
}
// Calculate the number new used elements that we can read from the ring.
// The ring index may wrap, so special handling for that case is needed.
count := int(ringIndex - r.lastIndex)
if count < 0 {
count += 0xffff
count := r.availableToTake()
if count == 0 {
return UsedElement{}, false
}
// The number of new elements can never exceed the queue size.
@@ -150,11 +146,7 @@ func (r *UsedRing) takeOne() (uint16, bool) {
panic("used ring contains more new elements than the ring is long")
}
if count == 0 {
return 0xffff, false
}
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
out := r.ring[r.lastIndex%uint16(len(r.ring))]
r.lastIndex++
return out, true

View File

@@ -14,7 +14,7 @@ import (
const Size = 0xffff
type Packet struct {
type UDPPacket struct {
Payload []byte
Control []byte
Name []byte
@@ -25,8 +25,8 @@ type Packet struct {
isV4 bool
}
func New(isV4 bool) *Packet {
return &Packet{
func New(isV4 bool) *UDPPacket {
return &UDPPacket{
Payload: make([]byte, Size),
Control: make([]byte, unix.CmsgSpace(2)),
Name: make([]byte, unix.SizeofSockaddrInet6),
@@ -34,7 +34,7 @@ func New(isV4 bool) *Packet {
}
}
func (p *Packet) AddrPort() netip.AddrPort {
func (p *UDPPacket) AddrPort() netip.AddrPort {
var ip netip.Addr
// It's ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if p.isV4 {
@@ -45,7 +45,7 @@ func (p *Packet) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
}
func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
func (p *UDPPacket) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
//todo no chance this works on windows?
if p.isV4 {
if !addr.Addr().Is4() {
@@ -69,7 +69,7 @@ func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error)
return uint32(size), nil
}
func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
func (p *UDPPacket) SetAddrPort(addr netip.AddrPort) error {
nl, err := p.encodeSockaddr(p.Name, addr)
if err != nil {
return err
@@ -78,7 +78,7 @@ func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
return nil
}
func (p *Packet) updateCtrl(ctrlLen int) {
func (p *UDPPacket) updateCtrl(ctrlLen int) {
p.SegSize = len(p.Payload)
p.wasSegmented = false
if ctrlLen == 0 {
@@ -101,12 +101,12 @@ func (p *Packet) updateCtrl(ctrlLen int) {
}
}
// Update sets a Packet into "just received, not processed" state
func (p *Packet) Update(ctrlLen int) {
// Update sets a UDPPacket into "just received, not processed" state
func (p *UDPPacket) Update(ctrlLen int) {
p.updateCtrl(ctrlLen)
}
func (p *Packet) SetSegSizeForTX() {
func (p *UDPPacket) SetSegSizeForTX() {
p.SegSize = len(p.Payload)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
hdr.Level = unix.SOL_UDP
@@ -115,7 +115,7 @@ func (p *Packet) SetSegSizeForTX() {
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
}
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
func (p *UDPPacket) CompatibleForSegmentationWith(otherP *UDPPacket, currentTotalSize int) bool {
//same dest
if !slices.Equal(p.Name, otherP.Name) {
return false
@@ -134,7 +134,7 @@ func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize
return true
}
func (p *Packet) Segments() iter.Seq[[]byte] {
func (p *UDPPacket) Segments() iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
//cursor := 0
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
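For GSO batches, Segments walks Payload in SegSize steps as a Go 1.23 iter.Seq, so a sender can range over it directly. A usage sketch (pkt and send are hypothetical names):

for seg := range pkt.Segments() {
	send(seg) // each seg is one SegSize-bounded slice of Payload
}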

View File

@@ -1,26 +0,0 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
)
type VirtIOPacket struct {
Payload []byte
Header virtio.NetHdr
Chains []uint16
ChainRefs [][]byte
}
func NewVIO() *VirtIOPacket {
out := new(VirtIOPacket)
out.Payload = nil
out.ChainRefs = make([][]byte, 0, 4)
out.Chains = make([]uint16, 0, 8)
return out
}
func (v *VirtIOPacket) Reset() {
v.Payload = nil
v.ChainRefs = v.ChainRefs[:0]
v.Chains = v.Chains[:0]
}

View File

@@ -10,7 +10,7 @@ import (
const MTU = 9001
type EncReader func(
[]*packet.Packet,
[]*packet.UDPPacket,
)
type Conn interface {
@@ -19,8 +19,8 @@ type Conn interface {
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Prep(pkt *packet.Packet, addr netip.AddrPort) error
WriteBatch(pkt []*packet.Packet) (int, error)
Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error
WriteBatch(pkt []*packet.UDPPacket) (int, error)
SupportsMultipleReaders() bool
Close() error
}

View File

@@ -215,7 +215,7 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip)
}
func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
func (u *StdConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
//todo move this into pkt
nl, err := u.encodeSockaddr(pkt.Name, addr)
if err != nil {
@@ -226,7 +226,7 @@ func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
return nil
}
func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
func (u *StdConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
if len(pkts) == 0 {
return 0, nil
}
@@ -235,7 +235,7 @@ func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
//u.iovs = u.iovs[:0]
sent := 0
var mostRecentPkt *packet.Packet
var mostRecentPkt *packet.UDPPacket
mostRecentPktSize := 0
//segmenting := false
idx := 0

View File

@@ -52,9 +52,9 @@ func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.UDPPacket) {
msgs := make([]rawMessage, n)
packets := make([]*packet.Packet, n)
packets := make([]*packet.UDPPacket, n)
for i := range msgs {
packets[i] = packet.New(isV4)

View File

@@ -41,7 +41,7 @@ type TesterConn struct {
l *logrus.Logger
}
func (u *TesterConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
func (u *TesterConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
pkt.ReadyToSend = true
return pkt.SetAddrPort(addr)
}
@@ -96,7 +96,7 @@ func (u *TesterConn) Get(block bool) *Packet {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (u *TesterConn) WriteBatch(pkts []*packet.Packet) (int, error) {
func (u *TesterConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
for _, pkt := range pkts {
if !pkt.ReadyToSend {
continue
@@ -141,7 +141,7 @@ func (u *TesterConn) ListenOut(r EncReader) {
if err != nil {
panic(err)
}
y := []*packet.Packet{x}
y := []*packet.UDPPacket{x}
r(y)
}
}