diff --git a/inside.go b/inside.go index 1896df8..ac65a3f 100644 --- a/inside.go +++ b/inside.go @@ -13,7 +13,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.UDPPacket, q int, localCache firewall.ConntrackCache, now time.Time) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { @@ -412,7 +412,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } -func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) { +func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.UDPPacket, q int) { if ci.eKey == nil { return } diff --git a/interface.go b/interface.go index 241c0c0..5cde576 100644 --- a/interface.go +++ b/interface.go @@ -294,7 +294,7 @@ func (f *Interface) listenOut(q int) { toSend := make([][]byte, batch) - li.ListenOut(func(pkts []*packet.Packet) { + li.ListenOut(func(pkts []*packet.UDPPacket) { toSend = toSend[:0] for i := range outPackets { outPackets[i].SegCounter = 0 @@ -323,11 +323,11 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - packets := make([]*packet.VirtIOPacket, batch) - outPackets := make([]*packet.Packet, batch) + packets := reader.NewPacketArrays(batch) + + outPackets := make([]*packet.UDPPacket, batch) for i := 0; i < batch; i++ { - packets[i] = packet.NewVIO() - outPackets[i] = packet.New(false) //todo? 
+ outPackets[i] = packet.New(false) //todo isv4? } for { @@ -352,9 +352,8 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { now := time.Now() for i, pkt := range packets[:n] { outPackets[i].ReadyToSend = false - f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now) + f.consumeInsidePacket(pkt.GetPayload(), fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now) reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err? - pkt.Reset() } _, err = f.writers[queueNum].WriteBatch(outPackets[:n]) if err != nil { diff --git a/outside.go b/outside.go index 061b179..5ef80fa 100644 --- a/outside.go +++ b/outside.go @@ -359,7 +359,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe f.connectionManager.In(hostinfo) } -func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { +func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { for i, pkt := range packets { out[i].Scratch = out[i].Scratch[:0] via := ViaSender{UdpAddr: pkt.AddrPort()} diff --git a/overlay/packets.go b/overlay/packets.go new file mode 100644 index 0000000..4245d1c --- /dev/null +++ b/overlay/packets.go @@ -0,0 +1,36 @@ +package overlay + +//import ( +// "github.com/slackhq/nebula/util/virtio" +//) + +//type VirtIOPacket struct { +// Payload []byte +// Header virtio.NetHdr +// Chains []uint16 +// ChainRefs [][]byte +//} +// +//func NewVIO() *VirtIOPacket { +// out := new(VirtIOPacket) +// out.Payload = nil +// out.ChainRefs = make([][]byte, 0, 4) +// out.Chains = make([]uint16, 0, 8) +// return out +//} +// +//func (v *VirtIOPacket) Reset() { +// 
v.Payload = nil +// v.ChainRefs = v.ChainRefs[:0] +// v.Chains = v.Chains[:0] +//} + +// TunPacket is formerly VirtIOPacket +type TunPacket interface { + SetPayload([]byte) + GetPayload() []byte +} +type OutPacket interface { + SetPayload([]byte) + GetPayload() []byte +} diff --git a/overlay/tun.go b/overlay/tun.go index 1dc914c..d22af28 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -16,13 +16,15 @@ const DefaultMTU = 1300 type TunDev interface { io.WriteCloser - ReadMany(x []*packet.VirtIOPacket, q int) (int, error) + NewPacketArrays(batchSize int) []TunPacket + + ReadMany(x []TunPacket, q int) (int, error) + RecycleRxSeg(pkt TunPacket, kick bool, q int) error //todo this interface sux AllocSeg(pkt *packet.OutPacket, q int) (int, error) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) WriteMany(x []*packet.OutPacket, q int) (int, error) - RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error } // TODO: We may be able to remove routines @@ -31,8 +33,8 @@ type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefi func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): - tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) - return tun, nil + t := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) + return t, nil default: return newTun(c, l, vpnNetworks, routines > 1) diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 086a676..755e293 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -24,7 +24,11 @@ type disabledTun struct { l *logrus.Logger } -func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error { +func (t *disabledTun) NewPacketArrays(batchSize int) []TunPacket { + panic("implement me") //TODO +} + +func (*disabledTun) 
RecycleRxSeg(pkt TunPacket, kick bool, q int) error { return nil } @@ -131,8 +135,8 @@ func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) { return 0, fmt.Errorf("tun_disabled: WriteMany not implemented") } -func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) { - return t.Read(b[0].Payload) +func (t *disabledTun) ReadMany(b []TunPacket, _ int) (int, error) { + return t.Read(b[0].GetPayload()) } func (t *disabledTun) NewMultiQueueReader() (TunDev, error) { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 386b37d..cb1aa58 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,6 +4,7 @@ package overlay import ( + "context" "fmt" "net" "net/netip" @@ -183,6 +184,14 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n return t, nil } +func (t *tun) NewPacketArrays(batchSize int) []TunPacket { + inPackets := make([]TunPacket, batchSize) + for i := 0; i < batchSize; i++ { + inPackets[i] = vhostnet.NewVIO() + } + return inPackets +} + func (t *tun) reload(c *config.C, initial bool) error { routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { @@ -725,12 +734,24 @@ func (t *tun) Close() error { return nil } -func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) { - n, err := t.vdev[q].ReceivePackets(p) //we are TXing +func (t *tun) ReadMany(p []TunPacket, q int) (int, error) { + err := t.vdev[q].ReceiveQueue.WaitForUsedElements(context.TODO()) if err != nil { return 0, err } - return n, nil + i := 0 + for i = 0; i < len(p); i++ { + item, ok := t.vdev[q].ReceiveQueue.TakeSingleNoBlock() + if !ok { + break + } + pkt := p[i].(*vhostnet.VirtIOPacket) //todo I'm not happy about this but I don't want to change how memory is "owned" rn + _, err = t.vdev[q].ProcessRxChain(pkt, item) + if err != nil { + return i, err + } + } + return i, nil } func (t *tun) Write(b []byte) (int, error) { @@ -783,6 +805,9 @@ func (t *tun) 
WriteMany(x []*packet.OutPacket, q int) (int, error) { return maximum, nil } -func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error { - return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick) +func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error { + vpkt := pkt.(*vhostnet.VirtIOPacket) + err := t.vdev[q].ReceiveQueue.OfferDescriptorChains(vpkt.Chains, kick) + vpkt.Reset() //intentionally ignoring err! + return err } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 5ffbb95..96abd0f 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -106,7 +106,7 @@ func (t *TestTun) Name() string { return t.Device } -func (t *TestTun) ReadMany(x []*packet.VirtIOPacket, q int) (int, error) { +func (t *TestTun) ReadMany(x []TunPacket, q int) (int, error) { p, ok := <-t.rxPackets if !ok { return 0, os.ErrClosed @@ -165,7 +165,7 @@ func (t *TestTun) WriteMany(x []*packet.OutPacket, q int) (int, error) { return len(x), nil } -func (t *TestTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error { +func (t *TestTun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error { //todo this ought to maybe track something return nil } diff --git a/overlay/user.go b/overlay/user.go index 0a5857e..789d00d 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -38,7 +38,18 @@ type UserDevice struct { inboundWriter *io.PipeWriter } -func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error { +func (d *UserDevice) NewPacketArrays(batchSize int) []TunPacket { + //inPackets := make([]TunPacket, batchSize) + //outPackets := make([]OutPacket, batchSize) + panic("not implemented") //todo! 
+ //for i := 0; i < batchSize; i++ { + // inPackets[i] = vhostnet.NewVIO() + // outPackets[i] = packet.New(false) + //} + //return inPackets, outPackets +} + +func (d *UserDevice) RecycleRxSeg(pkt TunPacket, kick bool, q int) error { return nil } @@ -76,8 +87,12 @@ func (d *UserDevice) Close() error { return nil } -func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) { - return d.Read(b[0].Payload) +func (d *UserDevice) ReadMany(b []TunPacket, _ int) (int, error) { + _, err := d.Read(b[0].GetPayload()) + if err != nil { + return 0, err + } + return 1, nil } func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) { diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go index c154336..6e6109c 100644 --- a/overlay/vhostnet/device.go +++ b/overlay/vhostnet/device.go @@ -118,7 +118,7 @@ func NewDevice(options ...Option) (*Device, error) { return nil, fmt.Errorf("set transmit queue backend: %w", err) } - // Fully populate the receive queue with available buffers which the device + // Fully populate the rx queue with available buffers which the device // can write new packets into. if err = dev.refillReceiveQueue(); err != nil { return nil, fmt.Errorf("refill receive queue: %w", err) @@ -198,11 +198,8 @@ func (dev *Device) Close() error { // createQueue creates a new virtqueue and registers it with the vhost device // using the given index. 
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) { - var ( - queue *virtqueue.SplitQueue - err error - ) - if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil { + queue, err := virtqueue.NewSplitQueue(queueSize, itemSize) + if err != nil { return nil, fmt.Errorf("create virtqueue: %w", err) } if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil { @@ -218,10 +215,10 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) { idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs() if err == virtqueue.ErrNotEnoughFreeDescriptors { dev.fullTable = true - idx, err = dev.TransmitQueue.TakeSingle(context.TODO()) + idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO()) } } else { - idx, err = dev.TransmitQueue.TakeSingle(context.TODO()) + idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO()) } if err != nil { return 0, nil, fmt.Errorf("transmit queue: %w", err) @@ -271,18 +268,15 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error { return nil } -// processChains processes as many chains as needed to create one packet. The number of processed chains is returned. -func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) { +// ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned. 
+func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) { //read first element to see how many descriptors we need: pkt.Reset() - - err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs) + idx := uint16(chain.DescriptorIndex) + buf, err := dev.ReceiveQueue.GetDescriptorItem(idx) if err != nil { return 0, fmt.Errorf("get descriptor chain: %w", err) } - if len(pkt.ChainRefs) == 0 { - return 1, nil - } // The specification requires that the first descriptor chain starts // with a virtio-net header. It is not clear, whether it is also @@ -290,7 +284,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us // descriptor chain, but it is reasonable to assume that this is // always the case. // The decode method already does the buffer length check. - if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil { + if err = pkt.header.Decode(buf); err != nil { // The device misbehaved. There is no way we can gracefully // recover from this, because we don't know how many of the // following descriptor chains belong to this packet. @@ -298,72 +292,44 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us } //we have the header now: what do we need to do? - if int(pkt.Header.NumBuffers) > len(chains) { - return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains)) + if int(pkt.header.NumBuffers) > 1 { + return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1) } - if int(pkt.Header.NumBuffers) != 1 { - return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains)) + if int(pkt.header.NumBuffers) != 1 { + return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers)) } - if chains[0].Length > 16000 { + if chain.Length > 16000 { //todo! 
- return 1, fmt.Errorf("too big packet length: %d", chains[0].Length) + return 1, fmt.Errorf("too big packet length: %d", chain.Length) } //shift the buffer out of out: - pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length] - pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex)) + pkt.payload = buf[virtio.NetHdrSize:chain.Length] + pkt.Chains = append(pkt.Chains, idx) return 1, nil - - //cursor := n - virtio.NetHdrSize - // - //if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 { - // pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize] - // return 1, nil - //} - // - //i := 1 - //// we used chain 0 already - //for i = 1; i < len(chains); i++ { - // n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length)) - // if err != nil { - // // When this fails we may miss to free some descriptor chains. We - // // could try to mitigate this by deferring the freeing somehow, but - // // it's not worth the hassle. When this method fails, the queue will - // // be in a broken state anyway. - // return i, fmt.Errorf("get descriptor chain: %w", err) - // } - // cursor += n - //} - ////todo this has to be wrong - //pkt.Payload = pkt.Payload[:cursor] - //return i, nil } -func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) { - //todo optimize? 
- var chains []virtqueue.UsedElement - var err error - - chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out)) - if err != nil { - return 0, err - } - if len(chains) == 0 { - return 0, nil - } - - numPackets := 0 - chainsIdx := 0 - for numPackets = 0; chainsIdx < len(chains); numPackets++ { - if numPackets >= len(out) { - return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets) - } - numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:]) - if err != nil { - return 0, err - } - chainsIdx += numChains - } - - return numPackets, nil +type VirtIOPacket struct { + payload []byte + header virtio.NetHdr + Chains []uint16 +} + +func NewVIO() *VirtIOPacket { + out := new(VirtIOPacket) + out.payload = nil + out.Chains = make([]uint16, 0, 8) + return out +} + +func (v *VirtIOPacket) Reset() { + v.payload = nil + v.Chains = v.Chains[:0] +} + +func (v *VirtIOPacket) GetPayload() []byte { + return v.payload +} +func (v *VirtIOPacket) SetPayload(x []byte) { + v.payload = x //todo? } diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go index 298036f..7b4ca75 100644 --- a/overlay/virtqueue/descriptor_table.go +++ b/overlay/virtqueue/descriptor_table.go @@ -10,10 +10,6 @@ import ( ) var ( - // ErrDescriptorChainEmpty is returned when a descriptor chain would contain - // no buffers, which is not allowed. - ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed") - // ErrNotEnoughFreeDescriptors is returned when the free descriptors are // exhausted, meaning that the queue is full. ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full") @@ -272,59 +268,6 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) { return head, nil } -// TODO: Implement a zero-copy variant of createDescriptorChain? 
- -// getDescriptorChain returns the device-readable buffers (out buffers) and -// device-writable buffers (in buffers) of the descriptor chain that starts with -// the given head index. The descriptor chain must have been created using -// [createDescriptorChain] and must not have been freed yet (meaning that the -// head index must not be contained in the free chain). -// -// Be careful to only access the returned buffer slices when the device has not -// yet or is no longer using them. They must not be accessed after -// [freeDescriptorChain] has been called. -func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) { - if int(head) > len(dt.descriptors) { - return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) - } - - // Iterate over the chain. The iteration is limited to the queue size to - // avoid ending up in an endless loop when things go very wrong. - next := head - for range len(dt.descriptors) { - if next == dt.freeHeadIndex { - return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain) - } - - desc := &dt.descriptors[next] - - // The descriptor address points to memory not managed by Go, so this - // conversion is safe. See https://github.com/golang/go/issues/58625 - //goland:noinspection GoVetUnsafePointer - bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length) - - if desc.flags&descriptorFlagWritable == 0 { - outBuffers = append(outBuffers, bs) - } else { - inBuffers = append(inBuffers, bs) - } - - // Is this the tail of the chain? - if desc.flags&descriptorFlagHasNext == 0 { - break - } - - // Detect loops. 
- if desc.next == head { - return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain) - } - - next = desc.next - } - - return -} - func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) { if int(head) > len(dt.descriptors) { return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) @@ -339,121 +282,6 @@ func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) { return bs, nil } -func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error { - if int(head) > len(dt.descriptors) { - return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) - } - - // Iterate over the chain. The iteration is limited to the queue size to - // avoid ending up in an endless loop when things go very wrong. - next := head - for range len(dt.descriptors) { - if next == dt.freeHeadIndex { - return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain) - } - - desc := &dt.descriptors[next] - - // The descriptor address points to memory not managed by Go, so this - // conversion is safe. See https://github.com/golang/go/issues/58625 - //goland:noinspection GoVetUnsafePointer - bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length) - - if desc.flags&descriptorFlagWritable == 0 { - return fmt.Errorf("there should not be an outbuffer in %d", head) - } else { - *inBuffers = append(*inBuffers, bs) - } - - // Is this the tail of the chain? - if desc.flags&descriptorFlagHasNext == 0 { - break - } - - // Detect loops. - if desc.next == head { - return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain) - } - - next = desc.next - } - - return nil -} - -// freeDescriptorChain can be used to free a descriptor chain when it is no -// longer in use. The descriptor chain that starts with the given index will be -// put back into the free chain, so the descriptors can be used for later calls -// of [createDescriptorChain]. 
-// The descriptor chain must have been created using [createDescriptorChain] and -// must not have been freed yet (meaning that the head index must not be -// contained in the free chain). -func (dt *DescriptorTable) freeDescriptorChain(head uint16) error { - if int(head) > len(dt.descriptors) { - return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) - } - - // Iterate over the chain. The iteration is limited to the queue size to - // avoid ending up in an endless loop when things go very wrong. - next := head - var tailDesc *Descriptor - var chainLen uint16 - for range len(dt.descriptors) { - if next == dt.freeHeadIndex { - return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain) - } - - desc := &dt.descriptors[next] - chainLen++ - - // Set the length of all unused descriptors back to zero. - desc.length = 0 - - // Unset all flags except the next flag. - desc.flags &= descriptorFlagHasNext - - // Is this the tail of the chain? - if desc.flags&descriptorFlagHasNext == 0 { - tailDesc = desc - break - } - - // Detect loops. - if desc.next == head { - return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain) - } - - next = desc.next - } - if tailDesc == nil { - // A descriptor chain longer than the queue size but without loops - // should be impossible. - panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head)) - } - - // The tail descriptor does not have the next flag set, but when it comes - // back into the free chain, it should have. - tailDesc.flags = descriptorFlagHasNext - - if dt.freeHeadIndex == noFreeHead { - // The whole free chain was used up, so we turn this returned descriptor - // chain into the new free chain by completing the circle and using its - // head. - tailDesc.next = head - dt.freeHeadIndex = head - } else { - // Attach the returned chain at the beginning of the free chain but - // right after the free chain head. 
- freeHeadDesc := &dt.descriptors[dt.freeHeadIndex] - tailDesc.next = freeHeadDesc.next - freeHeadDesc.next = head - } - - dt.freeNum += chainLen - - return nil -} - // checkUnusedDescriptorLength asserts that the length of an unused descriptor // is zero, as it should be. // This is not a requirement by the virtio spec but rather a thing we do to diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go index 59421ea..d426b70 100644 --- a/overlay/virtqueue/split_virtqueue.go +++ b/overlay/virtqueue/split_virtqueue.go @@ -128,8 +128,7 @@ func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) { return nil, err } - // Consume used buffer notifications in the background. - sq.stop = sq.startConsumeUsedRing() + sq.stop = sq.kickSelfToExit() return &sq, nil } @@ -169,9 +168,7 @@ func (sq *SplitQueue) CallEventFD() int { return sq.callEventFD.FD() } -// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing]. -// A function is returned that can be used to gracefully cancel it. todo rename -func (sq *SplitQueue) startConsumeUsedRing() func() error { +func (sq *SplitQueue) kickSelfToExit() func() error { return func() error { // The goroutine blocks until it receives a signal on the event file @@ -185,7 +182,15 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error { } } -func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) { +func (sq *SplitQueue) TakeSingleIndex(ctx context.Context) (uint16, error) { + element, err := sq.TakeSingle(ctx) + if err != nil { + return 0xffff, err + } + return element.GetHead(), nil +} + +func (sq *SplitQueue) TakeSingle(ctx context.Context) (UsedElement, error) { var n int var err error for ctx.Err() == nil { @@ -195,7 +200,7 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) { } // Wait for a signal from the device. 
if n, err = sq.epoll.Block(); err != nil { - return 0, fmt.Errorf("wait: %w", err) + return UsedElement{}, fmt.Errorf("wait: %w", err) } if n > 0 { @@ -208,7 +213,31 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) { } } } - return 0, ctx.Err() + return UsedElement{}, ctx.Err() +} + +func (sq *SplitQueue) TakeSingleNoBlock() (UsedElement, bool) { + return sq.usedRing.takeOne() +} + +func (sq *SplitQueue) WaitForUsedElements(ctx context.Context) error { + if sq.usedRing.availableToTake() != 0 { + return nil + } + for ctx.Err() == nil { + // Wait for a signal from the device. + n, err := sq.epoll.Block() + if err != nil { + return fmt.Errorf("wait: %w", err) + } + if n > 0 { + _ = sq.epoll.Clear() + if sq.usedRing.availableToTake() != 0 { + return nil + } + } + } + return ctx.Err() } func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) { @@ -235,7 +264,7 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) return nil, fmt.Errorf("wait: %w", err) } if n > 0 { - _ = sq.epoll.Clear() //??? + _ = sq.epoll.Clear() stillNeedToTake, out = sq.usedRing.take(maxToTake) sq.more = stillNeedToTake return out, nil @@ -296,16 +325,14 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) { sq.availableRing.offerSingle(head) // Notify the device to make it process the updated available ring. - if err := sq.kickEventFD.Kick(); err != nil { + if err = sq.kickEventFD.Kick(); err != nil { return head, fmt.Errorf("notify device: %w", err) } return head, nil } -// GetDescriptorChain returns the device-readable buffers (out buffers) and -// device-writable buffers (in buffers) of the descriptor chain with the given -// head index. +// GetDescriptorItem returns the buffer of a given index // The head index must be one that was returned by a previous call to // [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been // freed yet. 
@@ -313,37 +340,11 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) { // Be careful to only access the returned buffer slices when the device is no // longer using them. They must not be accessed after // [SplitQueue.FreeDescriptorChain] has been called. -func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) { - return sq.descriptorTable.getDescriptorChain(head) -} - func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) { sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize) return sq.descriptorTable.getDescriptorItem(head) } -func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error { - return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers) -} - -// FreeDescriptorChain frees the descriptor chain with the given head index. -// The head index must be one that was returned by a previous call to -// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been -// freed yet. -// -// This creates new room in the queue which can be used by following -// [SplitQueue.OfferDescriptorChain] calls. -// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that -// are waiting for free room in the queue, they may become unblocked by this. 
-func (sq *SplitQueue) FreeDescriptorChain(head uint16) error { - //not called under lock - if err := sq.descriptorTable.freeDescriptorChain(head); err != nil { - return fmt.Errorf("free: %w", err) - } - - return nil -} - func (sq *SplitQueue) SetDescSize(head uint16, sz int) { //not called under lock sq.descriptorTable.descriptors[int(head)].length = uint32(sz) diff --git a/overlay/virtqueue/used_ring.go b/overlay/virtqueue/used_ring.go index 824c07c..acf65fe 100644 --- a/overlay/virtqueue/used_ring.go +++ b/overlay/virtqueue/used_ring.go @@ -84,17 +84,11 @@ func (r *UsedRing) Address() uintptr { return uintptr(unsafe.Pointer(r.flags)) } -// take returns all new [UsedElement]s that the device put into the ring and -// that weren't already returned by a previous call to this method. -// had a lock, I removed it -func (r *UsedRing) take(maxToTake int) (int, []UsedElement) { - //r.mu.Lock() - //defer r.mu.Unlock() - +func (r *UsedRing) availableToTake() int { ringIndex := *r.ringIndex if ringIndex == r.lastIndex { // Nothing new. - return 0, nil + return 0 } // Calculate the number new used elements that we can read from the ring. @@ -103,6 +97,16 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) { if count < 0 { count += 0xffff } + return count +} + +// take returns all new [UsedElement]s that the device put into the ring and +// that weren't already returned by a previous call to this method. +func (r *UsedRing) take(maxToTake int) (int, []UsedElement) { + count := r.availableToTake() + if count == 0 { + return 0, nil + } stillNeedToTake := 0 @@ -128,21 +132,13 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) { return stillNeedToTake, elems } -func (r *UsedRing) takeOne() (uint16, bool) { +func (r *UsedRing) takeOne() (UsedElement, bool) { //r.mu.Lock() //defer r.mu.Unlock() - ringIndex := *r.ringIndex - if ringIndex == r.lastIndex { - // Nothing new. 
- return 0xffff, false - } - - // Calculate the number new used elements that we can read from the ring. - // The ring index may wrap, so special handling for that case is needed. - count := int(ringIndex - r.lastIndex) - if count < 0 { - count += 0xffff + count := r.availableToTake() + if count == 0 { + return UsedElement{}, false } // The number of new elements can never exceed the queue size. @@ -150,11 +146,7 @@ func (r *UsedRing) takeOne() (uint16, bool) { panic("used ring contains more new elements than the ring is long") } - if count == 0 { - return 0xffff, false - } - - out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead() + out := r.ring[r.lastIndex%uint16(len(r.ring))] r.lastIndex++ return out, true diff --git a/packet/packet.go b/packet/packet.go index 022f31d..5ae6db6 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -14,7 +14,7 @@ import ( const Size = 0xffff -type Packet struct { +type UDPPacket struct { Payload []byte Control []byte Name []byte @@ -25,8 +25,8 @@ type Packet struct { isV4 bool } -func New(isV4 bool) *Packet { - return &Packet{ +func New(isV4 bool) *UDPPacket { + return &UDPPacket{ Payload: make([]byte, Size), Control: make([]byte, unix.CmsgSpace(2)), Name: make([]byte, unix.SizeofSockaddrInet6), @@ -34,7 +34,7 @@ func New(isV4 bool) *Packet { } } -func (p *Packet) AddrPort() netip.AddrPort { +func (p *UDPPacket) AddrPort() netip.AddrPort { var ip netip.Addr // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if p.isV4 { @@ -45,7 +45,7 @@ func (p *Packet) AddrPort() netip.AddrPort { return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4])) } -func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) { +func (p *UDPPacket) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) { //todo no chance this works on windows? 
if p.isV4 { if !addr.Addr().Is4() { @@ -69,7 +69,7 @@ func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) return uint32(size), nil } -func (p *Packet) SetAddrPort(addr netip.AddrPort) error { +func (p *UDPPacket) SetAddrPort(addr netip.AddrPort) error { nl, err := p.encodeSockaddr(p.Name, addr) if err != nil { return err @@ -78,7 +78,7 @@ func (p *Packet) SetAddrPort(addr netip.AddrPort) error { return nil } -func (p *Packet) updateCtrl(ctrlLen int) { +func (p *UDPPacket) updateCtrl(ctrlLen int) { p.SegSize = len(p.Payload) p.wasSegmented = false if ctrlLen == 0 { @@ -101,12 +101,12 @@ func (p *Packet) updateCtrl(ctrlLen int) { } } -// Update sets a Packet into "just received, not processed" state -func (p *Packet) Update(ctrlLen int) { +// Update sets a UDPPacket into "just received, not processed" state +func (p *UDPPacket) Update(ctrlLen int) { p.updateCtrl(ctrlLen) } -func (p *Packet) SetSegSizeForTX() { +func (p *UDPPacket) SetSegSizeForTX() { p.SegSize = len(p.Payload) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0])) hdr.Level = unix.SOL_UDP @@ -115,7 +115,7 @@ func (p *Packet) SetSegSizeForTX() { binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize)) } -func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool { +func (p *UDPPacket) CompatibleForSegmentationWith(otherP *UDPPacket, currentTotalSize int) bool { //same dest if !slices.Equal(p.Name, otherP.Name) { return false @@ -134,7 +134,7 @@ func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize return true } -func (p *Packet) Segments() iter.Seq[[]byte] { +func (p *UDPPacket) Segments() iter.Seq[[]byte] { return func(yield func([]byte) bool) { //cursor := 0 for offset := 0; offset < len(p.Payload); offset += p.SegSize { diff --git a/packet/virtio.go b/packet/virtio.go deleted file mode 100644 index e07a8ae..0000000 --- a/packet/virtio.go +++ /dev/null @@ -1,26 +0,0 @@ -package 
packet - -import ( - "github.com/slackhq/nebula/util/virtio" -) - -type VirtIOPacket struct { - Payload []byte - Header virtio.NetHdr - Chains []uint16 - ChainRefs [][]byte -} - -func NewVIO() *VirtIOPacket { - out := new(VirtIOPacket) - out.Payload = nil - out.ChainRefs = make([][]byte, 0, 4) - out.Chains = make([]uint16, 0, 8) - return out -} - -func (v *VirtIOPacket) Reset() { - v.Payload = nil - v.ChainRefs = v.ChainRefs[:0] - v.Chains = v.Chains[:0] -} diff --git a/udp/conn.go b/udp/conn.go index f249389..380f4db 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -10,7 +10,7 @@ import ( const MTU = 9001 type EncReader func( - []*packet.Packet, + []*packet.UDPPacket, ) type Conn interface { @@ -19,8 +19,8 @@ type Conn interface { ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) - Prep(pkt *packet.Packet, addr netip.AddrPort) error - WriteBatch(pkt []*packet.Packet) (int, error) + Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error + WriteBatch(pkt []*packet.UDPPacket) (int, error) SupportsMultipleReaders() bool Close() error } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index a70939a..8f3b092 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -215,7 +215,7 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } -func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error { +func (u *StdConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error { //todo move this into pkt nl, err := u.encodeSockaddr(pkt.Name, addr) if err != nil { @@ -226,7 +226,7 @@ func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error { return nil } -func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) { +func (u *StdConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) { if len(pkts) == 0 { return 0, nil } @@ -235,7 +235,7 @@ func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) { //u.iovs = u.iovs[:0] sent := 0 - var mostRecentPkt *packet.Packet 
+ var mostRecentPkt *packet.UDPPacket mostRecentPktSize := 0 //segmenting := false idx := 0 diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index e9e3ccb..f5022b4 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -52,9 +52,9 @@ func setCmsgLen(h *unix.Cmsghdr, l int) { h.Len = uint64(l) } -func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) { +func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.UDPPacket) { msgs := make([]rawMessage, n) - packets := make([]*packet.Packet, n) + packets := make([]*packet.UDPPacket, n) for i := range msgs { packets[i] = packet.New(isV4) diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 45e49f5..f618899 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -41,7 +41,7 @@ type TesterConn struct { l *logrus.Logger } -func (u *TesterConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error { +func (u *TesterConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error { pkt.ReadyToSend = true return pkt.SetAddrPort(addr) } @@ -96,7 +96,7 @@ func (u *TesterConn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *TesterConn) WriteBatch(pkts []*packet.Packet) (int, error) { +func (u *TesterConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) { for _, pkt := range pkts { if !pkt.ReadyToSend { continue @@ -141,7 +141,7 @@ func (u *TesterConn) ListenOut(r EncReader) { if err != nil { panic(err) } - y := []*packet.Packet{x} + y := []*packet.UDPPacket{x} r(y) } }