mirror of
https://github.com/slackhq/nebula.git
synced 2025-12-31 02:58:28 +01:00
refactoring a bit
This commit is contained in:
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.UDPPacket, q int, localCache firewall.ConntrackCache, now time.Time) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -412,7 +412,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
|
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.UDPPacket, q int) {
|
||||||
if ci.eKey == nil {
|
if ci.eKey == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
13
interface.go
13
interface.go
@@ -294,7 +294,7 @@ func (f *Interface) listenOut(q int) {
|
|||||||
|
|
||||||
toSend := make([][]byte, batch)
|
toSend := make([][]byte, batch)
|
||||||
|
|
||||||
li.ListenOut(func(pkts []*packet.Packet) {
|
li.ListenOut(func(pkts []*packet.UDPPacket) {
|
||||||
toSend = toSend[:0]
|
toSend = toSend[:0]
|
||||||
for i := range outPackets {
|
for i := range outPackets {
|
||||||
outPackets[i].SegCounter = 0
|
outPackets[i].SegCounter = 0
|
||||||
@@ -323,11 +323,11 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
|
|||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
packets := make([]*packet.VirtIOPacket, batch)
|
packets := reader.NewPacketArrays(batch)
|
||||||
outPackets := make([]*packet.Packet, batch)
|
|
||||||
|
outPackets := make([]*packet.UDPPacket, batch)
|
||||||
for i := 0; i < batch; i++ {
|
for i := 0; i < batch; i++ {
|
||||||
packets[i] = packet.NewVIO()
|
outPackets[i] = packet.New(false) //todo isv4?
|
||||||
outPackets[i] = packet.New(false) //todo?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -352,9 +352,8 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
for i, pkt := range packets[:n] {
|
for i, pkt := range packets[:n] {
|
||||||
outPackets[i].ReadyToSend = false
|
outPackets[i].ReadyToSend = false
|
||||||
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
|
f.consumeInsidePacket(pkt.GetPayload(), fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
|
||||||
reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
|
reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
|
||||||
pkt.Reset()
|
|
||||||
}
|
}
|
||||||
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
|
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -359,7 +359,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
|
|||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
|
func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
|
||||||
for i, pkt := range packets {
|
for i, pkt := range packets {
|
||||||
out[i].Scratch = out[i].Scratch[:0]
|
out[i].Scratch = out[i].Scratch[:0]
|
||||||
via := ViaSender{UdpAddr: pkt.AddrPort()}
|
via := ViaSender{UdpAddr: pkt.AddrPort()}
|
||||||
|
|||||||
36
overlay/packets.go
Normal file
36
overlay/packets.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package overlay
|
||||||
|
|
||||||
|
//import (
|
||||||
|
// "github.com/slackhq/nebula/util/virtio"
|
||||||
|
//)
|
||||||
|
|
||||||
|
//type VirtIOPacket struct {
|
||||||
|
// Payload []byte
|
||||||
|
// Header virtio.NetHdr
|
||||||
|
// Chains []uint16
|
||||||
|
// ChainRefs [][]byte
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func NewVIO() *VirtIOPacket {
|
||||||
|
// out := new(VirtIOPacket)
|
||||||
|
// out.Payload = nil
|
||||||
|
// out.ChainRefs = make([][]byte, 0, 4)
|
||||||
|
// out.Chains = make([]uint16, 0, 8)
|
||||||
|
// return out
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (v *VirtIOPacket) Reset() {
|
||||||
|
// v.Payload = nil
|
||||||
|
// v.ChainRefs = v.ChainRefs[:0]
|
||||||
|
// v.Chains = v.Chains[:0]
|
||||||
|
//}
|
||||||
|
|
||||||
|
// TunPacket is formerly VirtIOPacket
|
||||||
|
type TunPacket interface {
|
||||||
|
SetPayload([]byte)
|
||||||
|
GetPayload() []byte
|
||||||
|
}
|
||||||
|
type OutPacket interface {
|
||||||
|
SetPayload([]byte)
|
||||||
|
GetPayload() []byte
|
||||||
|
}
|
||||||
@@ -16,13 +16,15 @@ const DefaultMTU = 1300
|
|||||||
|
|
||||||
type TunDev interface {
|
type TunDev interface {
|
||||||
io.WriteCloser
|
io.WriteCloser
|
||||||
ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
|
NewPacketArrays(batchSize int) []TunPacket
|
||||||
|
|
||||||
|
ReadMany(x []TunPacket, q int) (int, error)
|
||||||
|
RecycleRxSeg(pkt TunPacket, kick bool, q int) error
|
||||||
|
|
||||||
//todo this interface sux
|
//todo this interface sux
|
||||||
AllocSeg(pkt *packet.OutPacket, q int) (int, error)
|
AllocSeg(pkt *packet.OutPacket, q int) (int, error)
|
||||||
WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
|
WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
|
||||||
WriteMany(x []*packet.OutPacket, q int) (int, error)
|
WriteMany(x []*packet.OutPacket, q int) (int, error)
|
||||||
RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
@@ -31,8 +33,8 @@ type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefi
|
|||||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
switch {
|
switch {
|
||||||
case c.GetBool("tun.disabled", false):
|
case c.GetBool("tun.disabled", false):
|
||||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
t := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||||
return tun, nil
|
return t, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return newTun(c, l, vpnNetworks, routines > 1)
|
return newTun(c, l, vpnNetworks, routines > 1)
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ type disabledTun struct {
|
|||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
func (t *disabledTun) NewPacketArrays(batchSize int) []TunPacket {
|
||||||
|
panic("implement me") //TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*disabledTun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,8 +135,8 @@ func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
|||||||
return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
|
return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
|
func (t *disabledTun) ReadMany(b []TunPacket, _ int) (int, error) {
|
||||||
return t.Read(b[0].Payload)
|
return t.Read(b[0].GetPayload())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
|
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -183,6 +184,14 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) NewPacketArrays(batchSize int) []TunPacket {
|
||||||
|
inPackets := make([]TunPacket, batchSize)
|
||||||
|
for i := 0; i < batchSize; i++ {
|
||||||
|
inPackets[i] = vhostnet.NewVIO()
|
||||||
|
}
|
||||||
|
return inPackets
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -725,12 +734,25 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
|
func (t *tun) ReadMany(p []TunPacket, q int) (int, error) {
|
||||||
n, err := t.vdev[q].ReceivePackets(p) //we are TXing
|
err := t.vdev[q].ReceiveQueue.WaitForUsedElements(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return n, nil
|
i := 0
|
||||||
|
for i = 0; i < len(p); i++ {
|
||||||
|
item, ok := t.vdev[q].ReceiveQueue.TakeSingleNoBlock()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pkt := p[i].(*vhostnet.VirtIOPacket) //todo I'm not happy about this but I don't want to change how memory is "owned" rn
|
||||||
|
_, err = t.vdev[q].ProcessRxChain(pkt, item)
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return i, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
func (t *tun) Write(b []byte) (int, error) {
|
||||||
@@ -783,6 +805,9 @@ func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
|||||||
return maximum, nil
|
return maximum, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
|
||||||
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
|
vpkt := pkt.(*vhostnet.VirtIOPacket)
|
||||||
|
err := t.vdev[q].ReceiveQueue.OfferDescriptorChains(vpkt.Chains, kick)
|
||||||
|
vpkt.Reset() //intentionally ignoring err!
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ func (t *TestTun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) ReadMany(x []*packet.VirtIOPacket, q int) (int, error) {
|
func (t *TestTun) ReadMany(x []TunPacket, q int) (int, error) {
|
||||||
p, ok := <-t.rxPackets
|
p, ok := <-t.rxPackets
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
@@ -165,7 +165,7 @@ func (t *TestTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
|||||||
return len(x), nil
|
return len(x), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
func (t *TestTun) RecycleRxSeg(pkt *TunPacket, kick bool, q int) error {
|
||||||
//todo this ought to maybe track something
|
//todo this ought to maybe track something
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,18 @@ type UserDevice struct {
|
|||||||
inboundWriter *io.PipeWriter
|
inboundWriter *io.PipeWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
func (d *UserDevice) NewPacketArrays(batchSize int) []TunPacket {
|
||||||
|
//inPackets := make([]TunPacket, batchSize)
|
||||||
|
//outPackets := make([]OutPacket, batchSize)
|
||||||
|
panic("not implemented") //todo!
|
||||||
|
//for i := 0; i < batchSize; i++ {
|
||||||
|
// inPackets[i] = vhostnet.NewVIO()
|
||||||
|
// outPackets[i] = packet.New(false)
|
||||||
|
//}
|
||||||
|
//return inPackets, outPackets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UserDevice) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,8 +87,12 @@ func (d *UserDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
|
func (d *UserDevice) ReadMany(b []TunPacket, _ int) (int, error) {
|
||||||
return d.Read(b[0].Payload)
|
_, err := d.Read(b[0].GetPayload())
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ func NewDevice(options ...Option) (*Device, error) {
|
|||||||
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fully populate the receive queue with available buffers which the device
|
// Fully populate the rx queue with available buffers which the device
|
||||||
// can write new packets into.
|
// can write new packets into.
|
||||||
if err = dev.refillReceiveQueue(); err != nil {
|
if err = dev.refillReceiveQueue(); err != nil {
|
||||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
return nil, fmt.Errorf("refill receive queue: %w", err)
|
||||||
@@ -198,11 +198,8 @@ func (dev *Device) Close() error {
|
|||||||
// createQueue creates a new virtqueue and registers it with the vhost device
|
// createQueue creates a new virtqueue and registers it with the vhost device
|
||||||
// using the given index.
|
// using the given index.
|
||||||
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
|
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
|
||||||
var (
|
queue, err := virtqueue.NewSplitQueue(queueSize, itemSize)
|
||||||
queue *virtqueue.SplitQueue
|
if err != nil {
|
||||||
err error
|
|
||||||
)
|
|
||||||
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
|
|
||||||
return nil, fmt.Errorf("create virtqueue: %w", err)
|
return nil, fmt.Errorf("create virtqueue: %w", err)
|
||||||
}
|
}
|
||||||
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
||||||
@@ -218,10 +215,10 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
|
|||||||
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
||||||
if err == virtqueue.ErrNotEnoughFreeDescriptors {
|
if err == virtqueue.ErrNotEnoughFreeDescriptors {
|
||||||
dev.fullTable = true
|
dev.fullTable = true
|
||||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
||||||
@@ -271,18 +268,15 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
|
// ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned.
|
||||||
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
|
func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) {
|
||||||
//read first element to see how many descriptors we need:
|
//read first element to see how many descriptors we need:
|
||||||
pkt.Reset()
|
pkt.Reset()
|
||||||
|
idx := uint16(chain.DescriptorIndex)
|
||||||
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
|
buf, err := dev.ReceiveQueue.GetDescriptorItem(idx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
||||||
}
|
}
|
||||||
if len(pkt.ChainRefs) == 0 {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The specification requires that the first descriptor chain starts
|
// The specification requires that the first descriptor chain starts
|
||||||
// with a virtio-net header. It is not clear, whether it is also
|
// with a virtio-net header. It is not clear, whether it is also
|
||||||
@@ -290,7 +284,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
|
|||||||
// descriptor chain, but it is reasonable to assume that this is
|
// descriptor chain, but it is reasonable to assume that this is
|
||||||
// always the case.
|
// always the case.
|
||||||
// The decode method already does the buffer length check.
|
// The decode method already does the buffer length check.
|
||||||
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
|
if err = pkt.header.Decode(buf); err != nil {
|
||||||
// The device misbehaved. There is no way we can gracefully
|
// The device misbehaved. There is no way we can gracefully
|
||||||
// recover from this, because we don't know how many of the
|
// recover from this, because we don't know how many of the
|
||||||
// following descriptor chains belong to this packet.
|
// following descriptor chains belong to this packet.
|
||||||
@@ -298,72 +292,44 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
|
|||||||
}
|
}
|
||||||
|
|
||||||
//we have the header now: what do we need to do?
|
//we have the header now: what do we need to do?
|
||||||
if int(pkt.Header.NumBuffers) > len(chains) {
|
if int(pkt.header.NumBuffers) > 1 {
|
||||||
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
|
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1)
|
||||||
}
|
}
|
||||||
if int(pkt.Header.NumBuffers) != 1 {
|
if int(pkt.header.NumBuffers) != 1 {
|
||||||
return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
|
return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers))
|
||||||
}
|
}
|
||||||
if chains[0].Length > 16000 {
|
if chain.Length > 16000 {
|
||||||
//todo!
|
//todo!
|
||||||
return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
|
return 1, fmt.Errorf("too big packet length: %d", chain.Length)
|
||||||
}
|
}
|
||||||
|
|
||||||
//shift the buffer out of out:
|
//shift the buffer out of out:
|
||||||
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
|
pkt.payload = buf[virtio.NetHdrSize:chain.Length]
|
||||||
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
|
pkt.Chains = append(pkt.Chains, idx)
|
||||||
return 1, nil
|
return 1, nil
|
||||||
|
|
||||||
//cursor := n - virtio.NetHdrSize
|
|
||||||
//
|
|
||||||
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
|
|
||||||
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
|
|
||||||
// return 1, nil
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//i := 1
|
|
||||||
//// we used chain 0 already
|
|
||||||
//for i = 1; i < len(chains); i++ {
|
|
||||||
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
|
|
||||||
// if err != nil {
|
|
||||||
// // When this fails we may miss to free some descriptor chains. We
|
|
||||||
// // could try to mitigate this by deferring the freeing somehow, but
|
|
||||||
// // it's not worth the hassle. When this method fails, the queue will
|
|
||||||
// // be in a broken state anyway.
|
|
||||||
// return i, fmt.Errorf("get descriptor chain: %w", err)
|
|
||||||
// }
|
|
||||||
// cursor += n
|
|
||||||
//}
|
|
||||||
////todo this has to be wrong
|
|
||||||
//pkt.Payload = pkt.Payload[:cursor]
|
|
||||||
//return i, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
type VirtIOPacket struct {
|
||||||
//todo optimize?
|
payload []byte
|
||||||
var chains []virtqueue.UsedElement
|
header virtio.NetHdr
|
||||||
var err error
|
Chains []uint16
|
||||||
|
}
|
||||||
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
|
|
||||||
if err != nil {
|
func NewVIO() *VirtIOPacket {
|
||||||
return 0, err
|
out := new(VirtIOPacket)
|
||||||
}
|
out.payload = nil
|
||||||
if len(chains) == 0 {
|
out.Chains = make([]uint16, 0, 8)
|
||||||
return 0, nil
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
numPackets := 0
|
func (v *VirtIOPacket) Reset() {
|
||||||
chainsIdx := 0
|
v.payload = nil
|
||||||
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
|
v.Chains = v.Chains[:0]
|
||||||
if numPackets >= len(out) {
|
}
|
||||||
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
|
|
||||||
}
|
func (v *VirtIOPacket) GetPayload() []byte {
|
||||||
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
|
return v.payload
|
||||||
if err != nil {
|
}
|
||||||
return 0, err
|
func (v *VirtIOPacket) SetPayload(x []byte) {
|
||||||
}
|
v.payload = x //todo?
|
||||||
chainsIdx += numChains
|
|
||||||
}
|
|
||||||
|
|
||||||
return numPackets, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
|
|
||||||
// no buffers, which is not allowed.
|
|
||||||
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
|
|
||||||
|
|
||||||
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
|
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
|
||||||
// exhausted, meaning that the queue is full.
|
// exhausted, meaning that the queue is full.
|
||||||
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
|
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
|
||||||
@@ -272,59 +268,6 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
|
|||||||
return head, nil
|
return head, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Implement a zero-copy variant of createDescriptorChain?
|
|
||||||
|
|
||||||
// getDescriptorChain returns the device-readable buffers (out buffers) and
|
|
||||||
// device-writable buffers (in buffers) of the descriptor chain that starts with
|
|
||||||
// the given head index. The descriptor chain must have been created using
|
|
||||||
// [createDescriptorChain] and must not have been freed yet (meaning that the
|
|
||||||
// head index must not be contained in the free chain).
|
|
||||||
//
|
|
||||||
// Be careful to only access the returned buffer slices when the device has not
|
|
||||||
// yet or is no longer using them. They must not be accessed after
|
|
||||||
// [freeDescriptorChain] has been called.
|
|
||||||
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
outBuffers = append(outBuffers, bs)
|
|
||||||
} else {
|
|
||||||
inBuffers = append(inBuffers, bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
|
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
|
||||||
if int(head) > len(dt.descriptors) {
|
if int(head) > len(dt.descriptors) {
|
||||||
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||||
@@ -339,121 +282,6 @@ func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
|
|||||||
return bs, nil
|
return bs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
return fmt.Errorf("there should not be an outbuffer in %d", head)
|
|
||||||
} else {
|
|
||||||
*inBuffers = append(*inBuffers, bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// freeDescriptorChain can be used to free a descriptor chain when it is no
|
|
||||||
// longer in use. The descriptor chain that starts with the given index will be
|
|
||||||
// put back into the free chain, so the descriptors can be used for later calls
|
|
||||||
// of [createDescriptorChain].
|
|
||||||
// The descriptor chain must have been created using [createDescriptorChain] and
|
|
||||||
// must not have been freed yet (meaning that the head index must not be
|
|
||||||
// contained in the free chain).
|
|
||||||
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
var tailDesc *Descriptor
|
|
||||||
var chainLen uint16
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
chainLen++
|
|
||||||
|
|
||||||
// Set the length of all unused descriptors back to zero.
|
|
||||||
desc.length = 0
|
|
||||||
|
|
||||||
// Unset all flags except the next flag.
|
|
||||||
desc.flags &= descriptorFlagHasNext
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
tailDesc = desc
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
if tailDesc == nil {
|
|
||||||
// A descriptor chain longer than the queue size but without loops
|
|
||||||
// should be impossible.
|
|
||||||
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The tail descriptor does not have the next flag set, but when it comes
|
|
||||||
// back into the free chain, it should have.
|
|
||||||
tailDesc.flags = descriptorFlagHasNext
|
|
||||||
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
// The whole free chain was used up, so we turn this returned descriptor
|
|
||||||
// chain into the new free chain by completing the circle and using its
|
|
||||||
// head.
|
|
||||||
tailDesc.next = head
|
|
||||||
dt.freeHeadIndex = head
|
|
||||||
} else {
|
|
||||||
// Attach the returned chain at the beginning of the free chain but
|
|
||||||
// right after the free chain head.
|
|
||||||
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
|
|
||||||
tailDesc.next = freeHeadDesc.next
|
|
||||||
freeHeadDesc.next = head
|
|
||||||
}
|
|
||||||
|
|
||||||
dt.freeNum += chainLen
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
|
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
|
||||||
// is zero, as it should be.
|
// is zero, as it should be.
|
||||||
// This is not a requirement by the virtio spec but rather a thing we do to
|
// This is not a requirement by the virtio spec but rather a thing we do to
|
||||||
|
|||||||
@@ -128,8 +128,7 @@ func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consume used buffer notifications in the background.
|
sq.stop = sq.kickSelfToExit()
|
||||||
sq.stop = sq.startConsumeUsedRing()
|
|
||||||
|
|
||||||
return &sq, nil
|
return &sq, nil
|
||||||
}
|
}
|
||||||
@@ -169,9 +168,7 @@ func (sq *SplitQueue) CallEventFD() int {
|
|||||||
return sq.callEventFD.FD()
|
return sq.callEventFD.FD()
|
||||||
}
|
}
|
||||||
|
|
||||||
// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
|
func (sq *SplitQueue) kickSelfToExit() func() error {
|
||||||
// A function is returned that can be used to gracefully cancel it. todo rename
|
|
||||||
func (sq *SplitQueue) startConsumeUsedRing() func() error {
|
|
||||||
return func() error {
|
return func() error {
|
||||||
|
|
||||||
// The goroutine blocks until it receives a signal on the event file
|
// The goroutine blocks until it receives a signal on the event file
|
||||||
@@ -185,7 +182,15 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
|
func (sq *SplitQueue) TakeSingleIndex(ctx context.Context) (uint16, error) {
|
||||||
|
element, err := sq.TakeSingle(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0xffff, err
|
||||||
|
}
|
||||||
|
return element.GetHead(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sq *SplitQueue) TakeSingle(ctx context.Context) (UsedElement, error) {
|
||||||
var n int
|
var n int
|
||||||
var err error
|
var err error
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
@@ -195,7 +200,7 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
|
|||||||
}
|
}
|
||||||
// Wait for a signal from the device.
|
// Wait for a signal from the device.
|
||||||
if n, err = sq.epoll.Block(); err != nil {
|
if n, err = sq.epoll.Block(); err != nil {
|
||||||
return 0, fmt.Errorf("wait: %w", err)
|
return UsedElement{}, fmt.Errorf("wait: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
@@ -208,7 +213,31 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, ctx.Err()
|
return UsedElement{}, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sq *SplitQueue) TakeSingleNoBlock() (UsedElement, bool) {
|
||||||
|
return sq.usedRing.takeOne()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sq *SplitQueue) WaitForUsedElements(ctx context.Context) error {
|
||||||
|
if sq.usedRing.availableToTake() != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
// Wait for a signal from the device.
|
||||||
|
n, err := sq.epoll.Block()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("wait: %w", err)
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
_ = sq.epoll.Clear()
|
||||||
|
if sq.usedRing.availableToTake() != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
|
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
|
||||||
@@ -235,7 +264,7 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
|
|||||||
return nil, fmt.Errorf("wait: %w", err)
|
return nil, fmt.Errorf("wait: %w", err)
|
||||||
}
|
}
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
_ = sq.epoll.Clear() //???
|
_ = sq.epoll.Clear()
|
||||||
stillNeedToTake, out = sq.usedRing.take(maxToTake)
|
stillNeedToTake, out = sq.usedRing.take(maxToTake)
|
||||||
sq.more = stillNeedToTake
|
sq.more = stillNeedToTake
|
||||||
return out, nil
|
return out, nil
|
||||||
@@ -296,16 +325,14 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
|
|||||||
sq.availableRing.offerSingle(head)
|
sq.availableRing.offerSingle(head)
|
||||||
|
|
||||||
// Notify the device to make it process the updated available ring.
|
// Notify the device to make it process the updated available ring.
|
||||||
if err := sq.kickEventFD.Kick(); err != nil {
|
if err = sq.kickEventFD.Kick(); err != nil {
|
||||||
return head, fmt.Errorf("notify device: %w", err)
|
return head, fmt.Errorf("notify device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return head, nil
|
return head, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDescriptorChain returns the device-readable buffers (out buffers) and
|
// GetDescriptorItem returns the buffer of a given index
|
||||||
// device-writable buffers (in buffers) of the descriptor chain with the given
|
|
||||||
// head index.
|
|
||||||
// The head index must be one that was returned by a previous call to
|
// The head index must be one that was returned by a previous call to
|
||||||
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
||||||
// freed yet.
|
// freed yet.
|
||||||
@@ -313,37 +340,11 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
|
|||||||
// Be careful to only access the returned buffer slices when the device is no
|
// Be careful to only access the returned buffer slices when the device is no
|
||||||
// longer using them. They must not be accessed after
|
// longer using them. They must not be accessed after
|
||||||
// [SplitQueue.FreeDescriptorChain] has been called.
|
// [SplitQueue.FreeDescriptorChain] has been called.
|
||||||
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
|
|
||||||
return sq.descriptorTable.getDescriptorChain(head)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
|
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
|
||||||
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
|
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
|
||||||
return sq.descriptorTable.getDescriptorItem(head)
|
return sq.descriptorTable.getDescriptorItem(head)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
|
||||||
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FreeDescriptorChain frees the descriptor chain with the given head index.
|
|
||||||
// The head index must be one that was returned by a previous call to
|
|
||||||
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
|
||||||
// freed yet.
|
|
||||||
//
|
|
||||||
// This creates new room in the queue which can be used by following
|
|
||||||
// [SplitQueue.OfferDescriptorChain] calls.
|
|
||||||
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
|
|
||||||
// are waiting for free room in the queue, they may become unblocked by this.
|
|
||||||
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
|
|
||||||
//not called under lock
|
|
||||||
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
|
|
||||||
return fmt.Errorf("free: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
|
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
|
||||||
//not called under lock
|
//not called under lock
|
||||||
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
|
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
|
||||||
|
|||||||
@@ -84,17 +84,11 @@ func (r *UsedRing) Address() uintptr {
|
|||||||
return uintptr(unsafe.Pointer(r.flags))
|
return uintptr(unsafe.Pointer(r.flags))
|
||||||
}
|
}
|
||||||
|
|
||||||
// take returns all new [UsedElement]s that the device put into the ring and
|
func (r *UsedRing) availableToTake() int {
|
||||||
// that weren't already returned by a previous call to this method.
|
|
||||||
// had a lock, I removed it
|
|
||||||
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
ringIndex := *r.ringIndex
|
ringIndex := *r.ringIndex
|
||||||
if ringIndex == r.lastIndex {
|
if ringIndex == r.lastIndex {
|
||||||
// Nothing new.
|
// Nothing new.
|
||||||
return 0, nil
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the number new used elements that we can read from the ring.
|
// Calculate the number new used elements that we can read from the ring.
|
||||||
@@ -103,6 +97,16 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
|
|||||||
if count < 0 {
|
if count < 0 {
|
||||||
count += 0xffff
|
count += 0xffff
|
||||||
}
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// take returns all new [UsedElement]s that the device put into the ring and
|
||||||
|
// that weren't already returned by a previous call to this method.
|
||||||
|
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
|
||||||
|
count := r.availableToTake()
|
||||||
|
if count == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
stillNeedToTake := 0
|
stillNeedToTake := 0
|
||||||
|
|
||||||
@@ -128,21 +132,13 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
|
|||||||
return stillNeedToTake, elems
|
return stillNeedToTake, elems
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UsedRing) takeOne() (uint16, bool) {
|
func (r *UsedRing) takeOne() (UsedElement, bool) {
|
||||||
//r.mu.Lock()
|
//r.mu.Lock()
|
||||||
//defer r.mu.Unlock()
|
//defer r.mu.Unlock()
|
||||||
|
|
||||||
ringIndex := *r.ringIndex
|
count := r.availableToTake()
|
||||||
if ringIndex == r.lastIndex {
|
if count == 0 {
|
||||||
// Nothing new.
|
return UsedElement{}, false
|
||||||
return 0xffff, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the number new used elements that we can read from the ring.
|
|
||||||
// The ring index may wrap, so special handling for that case is needed.
|
|
||||||
count := int(ringIndex - r.lastIndex)
|
|
||||||
if count < 0 {
|
|
||||||
count += 0xffff
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The number of new elements can never exceed the queue size.
|
// The number of new elements can never exceed the queue size.
|
||||||
@@ -150,11 +146,7 @@ func (r *UsedRing) takeOne() (uint16, bool) {
|
|||||||
panic("used ring contains more new elements than the ring is long")
|
panic("used ring contains more new elements than the ring is long")
|
||||||
}
|
}
|
||||||
|
|
||||||
if count == 0 {
|
out := r.ring[r.lastIndex%uint16(len(r.ring))]
|
||||||
return 0xffff, false
|
|
||||||
}
|
|
||||||
|
|
||||||
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
|
|
||||||
r.lastIndex++
|
r.lastIndex++
|
||||||
|
|
||||||
return out, true
|
return out, true
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
const Size = 0xffff
|
const Size = 0xffff
|
||||||
|
|
||||||
type Packet struct {
|
type UDPPacket struct {
|
||||||
Payload []byte
|
Payload []byte
|
||||||
Control []byte
|
Control []byte
|
||||||
Name []byte
|
Name []byte
|
||||||
@@ -25,8 +25,8 @@ type Packet struct {
|
|||||||
isV4 bool
|
isV4 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(isV4 bool) *Packet {
|
func New(isV4 bool) *UDPPacket {
|
||||||
return &Packet{
|
return &UDPPacket{
|
||||||
Payload: make([]byte, Size),
|
Payload: make([]byte, Size),
|
||||||
Control: make([]byte, unix.CmsgSpace(2)),
|
Control: make([]byte, unix.CmsgSpace(2)),
|
||||||
Name: make([]byte, unix.SizeofSockaddrInet6),
|
Name: make([]byte, unix.SizeofSockaddrInet6),
|
||||||
@@ -34,7 +34,7 @@ func New(isV4 bool) *Packet {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) AddrPort() netip.AddrPort {
|
func (p *UDPPacket) AddrPort() netip.AddrPort {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
if p.isV4 {
|
if p.isV4 {
|
||||||
@@ -45,7 +45,7 @@ func (p *Packet) AddrPort() netip.AddrPort {
|
|||||||
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
|
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
func (p *UDPPacket) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
||||||
//todo no chance this works on windows?
|
//todo no chance this works on windows?
|
||||||
if p.isV4 {
|
if p.isV4 {
|
||||||
if !addr.Addr().Is4() {
|
if !addr.Addr().Is4() {
|
||||||
@@ -69,7 +69,7 @@ func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error)
|
|||||||
return uint32(size), nil
|
return uint32(size), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
|
func (p *UDPPacket) SetAddrPort(addr netip.AddrPort) error {
|
||||||
nl, err := p.encodeSockaddr(p.Name, addr)
|
nl, err := p.encodeSockaddr(p.Name, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -78,7 +78,7 @@ func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) updateCtrl(ctrlLen int) {
|
func (p *UDPPacket) updateCtrl(ctrlLen int) {
|
||||||
p.SegSize = len(p.Payload)
|
p.SegSize = len(p.Payload)
|
||||||
p.wasSegmented = false
|
p.wasSegmented = false
|
||||||
if ctrlLen == 0 {
|
if ctrlLen == 0 {
|
||||||
@@ -101,12 +101,12 @@ func (p *Packet) updateCtrl(ctrlLen int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update sets a Packet into "just received, not processed" state
|
// Update sets a UDPPacket into "just received, not processed" state
|
||||||
func (p *Packet) Update(ctrlLen int) {
|
func (p *UDPPacket) Update(ctrlLen int) {
|
||||||
p.updateCtrl(ctrlLen)
|
p.updateCtrl(ctrlLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) SetSegSizeForTX() {
|
func (p *UDPPacket) SetSegSizeForTX() {
|
||||||
p.SegSize = len(p.Payload)
|
p.SegSize = len(p.Payload)
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
|
||||||
hdr.Level = unix.SOL_UDP
|
hdr.Level = unix.SOL_UDP
|
||||||
@@ -115,7 +115,7 @@ func (p *Packet) SetSegSizeForTX() {
|
|||||||
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
|
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
|
func (p *UDPPacket) CompatibleForSegmentationWith(otherP *UDPPacket, currentTotalSize int) bool {
|
||||||
//same dest
|
//same dest
|
||||||
if !slices.Equal(p.Name, otherP.Name) {
|
if !slices.Equal(p.Name, otherP.Name) {
|
||||||
return false
|
return false
|
||||||
@@ -134,7 +134,7 @@ func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) Segments() iter.Seq[[]byte] {
|
func (p *UDPPacket) Segments() iter.Seq[[]byte] {
|
||||||
return func(yield func([]byte) bool) {
|
return func(yield func([]byte) bool) {
|
||||||
//cursor := 0
|
//cursor := 0
|
||||||
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
|
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package packet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/slackhq/nebula/util/virtio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type VirtIOPacket struct {
|
|
||||||
Payload []byte
|
|
||||||
Header virtio.NetHdr
|
|
||||||
Chains []uint16
|
|
||||||
ChainRefs [][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVIO() *VirtIOPacket {
|
|
||||||
out := new(VirtIOPacket)
|
|
||||||
out.Payload = nil
|
|
||||||
out.ChainRefs = make([][]byte, 0, 4)
|
|
||||||
out.Chains = make([]uint16, 0, 8)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *VirtIOPacket) Reset() {
|
|
||||||
v.Payload = nil
|
|
||||||
v.ChainRefs = v.ChainRefs[:0]
|
|
||||||
v.Chains = v.Chains[:0]
|
|
||||||
}
|
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
const MTU = 9001
|
const MTU = 9001
|
||||||
|
|
||||||
type EncReader func(
|
type EncReader func(
|
||||||
[]*packet.Packet,
|
[]*packet.UDPPacket,
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
@@ -19,8 +19,8 @@ type Conn interface {
|
|||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Prep(pkt *packet.Packet, addr netip.AddrPort) error
|
Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error
|
||||||
WriteBatch(pkt []*packet.Packet) (int, error)
|
WriteBatch(pkt []*packet.UDPPacket) (int, error)
|
||||||
SupportsMultipleReaders() bool
|
SupportsMultipleReaders() bool
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -215,7 +215,7 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
|
|||||||
return u.writeTo6(b, ip)
|
return u.writeTo6(b, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
|
func (u *StdConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
|
||||||
//todo move this into pkt
|
//todo move this into pkt
|
||||||
nl, err := u.encodeSockaddr(pkt.Name, addr)
|
nl, err := u.encodeSockaddr(pkt.Name, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -226,7 +226,7 @@ func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
|
func (u *StdConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
|
||||||
if len(pkts) == 0 {
|
if len(pkts) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
@@ -235,7 +235,7 @@ func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
|
|||||||
//u.iovs = u.iovs[:0]
|
//u.iovs = u.iovs[:0]
|
||||||
|
|
||||||
sent := 0
|
sent := 0
|
||||||
var mostRecentPkt *packet.Packet
|
var mostRecentPkt *packet.UDPPacket
|
||||||
mostRecentPktSize := 0
|
mostRecentPktSize := 0
|
||||||
//segmenting := false
|
//segmenting := false
|
||||||
idx := 0
|
idx := 0
|
||||||
|
|||||||
@@ -52,9 +52,9 @@ func setCmsgLen(h *unix.Cmsghdr, l int) {
|
|||||||
h.Len = uint64(l)
|
h.Len = uint64(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
|
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.UDPPacket) {
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
packets := make([]*packet.Packet, n)
|
packets := make([]*packet.UDPPacket, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
packets[i] = packet.New(isV4)
|
packets[i] = packet.New(isV4)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type TesterConn struct {
|
|||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
|
func (u *TesterConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
|
||||||
pkt.ReadyToSend = true
|
pkt.ReadyToSend = true
|
||||||
return pkt.SetAddrPort(addr)
|
return pkt.SetAddrPort(addr)
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ func (u *TesterConn) Get(block bool) *Packet {
|
|||||||
// Below this is boilerplate implementation to make nebula actually work
|
// Below this is boilerplate implementation to make nebula actually work
|
||||||
//********************************************************************************************************************//
|
//********************************************************************************************************************//
|
||||||
|
|
||||||
func (u *TesterConn) WriteBatch(pkts []*packet.Packet) (int, error) {
|
func (u *TesterConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
|
||||||
for _, pkt := range pkts {
|
for _, pkt := range pkts {
|
||||||
if !pkt.ReadyToSend {
|
if !pkt.ReadyToSend {
|
||||||
continue
|
continue
|
||||||
@@ -141,7 +141,7 @@ func (u *TesterConn) ListenOut(r EncReader) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
y := []*packet.Packet{x}
|
y := []*packet.UDPPacket{x}
|
||||||
r(y)
|
r(y)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user