slackhq/nebula (https://github.com/slackhq/nebula.git)
refactoring a bit
@@ -13,7 +13,7 @@ import (
 	"github.com/slackhq/nebula/routing"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.UDPPacket, q int, localCache firewall.ConntrackCache, now time.Time) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		if f.l.Level >= logrus.DebugLevel {
@@ -412,7 +412,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	}
 }
 
-func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
+func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.UDPPacket, q int) {
	if ci.eKey == nil {
		return
	}
interface.go (13 lines changed)
@@ -294,7 +294,7 @@ func (f *Interface) listenOut(q int) {
 
 	toSend := make([][]byte, batch)
 
-	li.ListenOut(func(pkts []*packet.Packet) {
+	li.ListenOut(func(pkts []*packet.UDPPacket) {
 		toSend = toSend[:0]
 		for i := range outPackets {
 			outPackets[i].SegCounter = 0
@@ -323,11 +323,11 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
 
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
-	packets := make([]*packet.VirtIOPacket, batch)
-	outPackets := make([]*packet.Packet, batch)
+	packets := reader.NewPacketArrays(batch)
+
+	outPackets := make([]*packet.UDPPacket, batch)
 	for i := 0; i < batch; i++ {
-		packets[i] = packet.NewVIO()
-		outPackets[i] = packet.New(false) //todo?
+		outPackets[i] = packet.New(false) //todo isv4?
 	}
 
 	for {
@@ -352,9 +352,8 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
 		now := time.Now()
 		for i, pkt := range packets[:n] {
 			outPackets[i].ReadyToSend = false
-			f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
+			f.consumeInsidePacket(pkt.GetPayload(), fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
 			reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
-			pkt.Reset()
 		}
 		_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
 		if err != nil {
@@ -359,7 +359,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 	f.connectionManager.In(hostinfo)
 }
 
-func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
 	for i, pkt := range packets {
 		out[i].Scratch = out[i].Scratch[:0]
 		via := ViaSender{UdpAddr: pkt.AddrPort()}
overlay/packets.go (new file, 36 lines)
@@ -0,0 +1,36 @@
+package overlay
+
+//import (
+//	"github.com/slackhq/nebula/util/virtio"
+//)
+
+//type VirtIOPacket struct {
+//	Payload []byte
+//	Header virtio.NetHdr
+//	Chains []uint16
+//	ChainRefs [][]byte
+//}
+//
+//func NewVIO() *VirtIOPacket {
+//	out := new(VirtIOPacket)
+//	out.Payload = nil
+//	out.ChainRefs = make([][]byte, 0, 4)
+//	out.Chains = make([]uint16, 0, 8)
+//	return out
+//}
+//
+//func (v *VirtIOPacket) Reset() {
+//	v.Payload = nil
+//	v.ChainRefs = v.ChainRefs[:0]
+//	v.Chains = v.Chains[:0]
+//}
+
+// TunPacket is formerly VirtIOPacket
+type TunPacket interface {
+	SetPayload([]byte)
+	GetPayload() []byte
+}
+
+type OutPacket interface {
+	SetPayload([]byte)
+	GetPayload() []byte
+}
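For orientation: TunPacket and OutPacket only promise payload access, and ownership of the backing memory stays with the concrete device. A minimal sketch of a conforming implementation for a device that has no descriptor rings to manage (the heapPacket name is hypothetical, not part of this commit):

package overlay

// heapPacket is a hypothetical TunPacket backed by plain heap memory,
// for devices that do not hand out virtqueue descriptor buffers.
type heapPacket struct {
	payload []byte
}

func (h *heapPacket) SetPayload(b []byte) { h.payload = b }
func (h *heapPacket) GetPayload() []byte  { return h.payload }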
@@ -16,13 +16,15 @@ const DefaultMTU = 1300
 
 type TunDev interface {
 	io.WriteCloser
-	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
+	NewPacketArrays(batchSize int) []TunPacket
+
+	ReadMany(x []TunPacket, q int) (int, error)
+	RecycleRxSeg(pkt TunPacket, kick bool, q int) error
+
 	//todo this interface sux
 	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
 	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
 	WriteMany(x []*packet.OutPacket, q int) (int, error)
-	RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
 }
 
 // TODO: We may be able to remove routines
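The intended calling sequence for the reshaped read path, as exercised by Interface.listenIn above: allocate the batch once with NewPacketArrays, then read, consume, and recycle each packet, because recycling hands the buffer back to the device. A condensed sketch under those assumptions (handle is hypothetical, error handling elided):

pkts := dev.NewPacketArrays(batch) // device-specific TunPacket backing
for {
	n, err := dev.ReadMany(pkts, q)
	if err != nil {
		return err
	}
	for i, pkt := range pkts[:n] {
		handle(pkt.GetPayload()) // finish with the payload first:
		// RecycleRxSeg returns the buffer to the device; kick only
		// once, on the final packet of the batch.
		_ = dev.RecycleRxSeg(pkt, i == n-1, q)
	}
}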
@@ -31,8 +33,8 @@ type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefi
 func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
 	switch {
 	case c.GetBool("tun.disabled", false):
-		tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
-		return tun, nil
+		t := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
+		return t, nil
 
 	default:
 		return newTun(c, l, vpnNetworks, routines > 1)
@@ -24,7 +24,11 @@ type disabledTun struct {
 	l *logrus.Logger
 }
 
-func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+func (t *disabledTun) NewPacketArrays(batchSize int) []TunPacket {
+	panic("implement me") //TODO
+}
+
+func (*disabledTun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
 	return nil
 }
@@ -131,8 +135,8 @@ func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
 	return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
 }
 
-func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
-	return t.Read(b[0].Payload)
+func (t *disabledTun) ReadMany(b []TunPacket, _ int) (int, error) {
+	return t.Read(b[0].GetPayload())
 }
 
 func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
@@ -4,6 +4,7 @@
 package overlay
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"net/netip"
@@ -183,6 +184,14 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
 	return t, nil
 }
 
+func (t *tun) NewPacketArrays(batchSize int) []TunPacket {
+	inPackets := make([]TunPacket, batchSize)
+	for i := 0; i < batchSize; i++ {
+		inPackets[i] = vhostnet.NewVIO()
+	}
+	return inPackets
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 	routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
@@ -725,12 +734,25 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
-	n, err := t.vdev[q].ReceivePackets(p) //we are TXing
+func (t *tun) ReadMany(p []TunPacket, q int) (int, error) {
+	err := t.vdev[q].ReceiveQueue.WaitForUsedElements(context.TODO())
 	if err != nil {
 		return 0, err
 	}
-	return n, nil
+	i := 0
+	for ; i < len(p); i++ {
+		item, ok := t.vdev[q].ReceiveQueue.TakeSingleNoBlock()
+		if !ok {
+			break
+		}
+		pkt := p[i].(*vhostnet.VirtIOPacket) //todo I'm not happy about this but I don't want to change how memory is "owned" rn
+		_, err = t.vdev[q].ProcessRxChain(pkt, item)
+		if err != nil {
+			return i, err
+		}
+	}
+	return i, nil
 }
 
 func (t *tun) Write(b []byte) (int, error) {
@@ -783,6 +805,9 @@ func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
 	return maximum, nil
 }
 
-func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
-	return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
+func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
+	vpkt := pkt.(*vhostnet.VirtIOPacket)
+	err := t.vdev[q].ReceiveQueue.OfferDescriptorChains(vpkt.Chains, kick)
+	vpkt.Reset() //intentionally ignoring err!
+	return err
 }
@@ -106,7 +106,7 @@ func (t *TestTun) Name() string {
 	return t.Device
 }
 
-func (t *TestTun) ReadMany(x []*packet.VirtIOPacket, q int) (int, error) {
+func (t *TestTun) ReadMany(x []TunPacket, q int) (int, error) {
 	p, ok := <-t.rxPackets
 	if !ok {
 		return 0, os.ErrClosed
@@ -165,7 +165,7 @@ func (t *TestTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
 	return len(x), nil
 }
 
-func (t *TestTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+func (t *TestTun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
 	//todo this ought to maybe track something
 	return nil
 }
@@ -38,7 +38,18 @@ type UserDevice struct {
 	inboundWriter *io.PipeWriter
 }
 
-func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+func (d *UserDevice) NewPacketArrays(batchSize int) []TunPacket {
+	//inPackets := make([]TunPacket, batchSize)
+	//outPackets := make([]OutPacket, batchSize)
+	panic("not implemented") //todo!
+	//for i := 0; i < batchSize; i++ {
+	//	inPackets[i] = vhostnet.NewVIO()
+	//	outPackets[i] = packet.New(false)
+	//}
+	//return inPackets, outPackets
+}
+
+func (d *UserDevice) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
 	return nil
 }
@@ -76,8 +87,12 @@ func (d *UserDevice) Close() error {
 	return nil
 }
 
-func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
-	return d.Read(b[0].Payload)
+func (d *UserDevice) ReadMany(b []TunPacket, _ int) (int, error) {
+	_, err := d.Read(b[0].GetPayload())
+	if err != nil {
+		return 0, err
+	}
+	return 1, nil
 }
 
 func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
@@ -118,7 +118,7 @@ func NewDevice(options ...Option) (*Device, error) {
 		return nil, fmt.Errorf("set transmit queue backend: %w", err)
 	}
 
-	// Fully populate the receive queue with available buffers which the device
+	// Fully populate the rx queue with available buffers which the device
 	// can write new packets into.
 	if err = dev.refillReceiveQueue(); err != nil {
 		return nil, fmt.Errorf("refill receive queue: %w", err)
@@ -198,11 +198,8 @@ func (dev *Device) Close() error {
 // createQueue creates a new virtqueue and registers it with the vhost device
 // using the given index.
 func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
-	var (
-		queue *virtqueue.SplitQueue
-		err   error
-	)
-	if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
+	queue, err := virtqueue.NewSplitQueue(queueSize, itemSize)
+	if err != nil {
 		return nil, fmt.Errorf("create virtqueue: %w", err)
 	}
 	if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
@@ -218,10 +215,10 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
 		idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
 		if err == virtqueue.ErrNotEnoughFreeDescriptors {
 			dev.fullTable = true
-			idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+			idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
 		}
 	} else {
-		idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+		idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
 	}
 	if err != nil {
 		return 0, nil, fmt.Errorf("transmit queue: %w", err)
@@ -271,18 +268,15 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
 	return nil
 }
 
-// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
-func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
+// ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned.
+func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) {
 	//read first element to see how many descriptors we need:
 	pkt.Reset()
 
-	err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
+	idx := uint16(chain.DescriptorIndex)
+	buf, err := dev.ReceiveQueue.GetDescriptorItem(idx)
 	if err != nil {
 		return 0, fmt.Errorf("get descriptor chain: %w", err)
 	}
-	if len(pkt.ChainRefs) == 0 {
-		return 1, nil
-	}
 
 	// The specification requires that the first descriptor chain starts
 	// with a virtio-net header. It is not clear, whether it is also
@@ -290,7 +284,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
 	// descriptor chain, but it is reasonable to assume that this is
 	// always the case.
 	// The decode method already does the buffer length check.
-	if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
+	if err = pkt.header.Decode(buf); err != nil {
 		// The device misbehaved. There is no way we can gracefully
 		// recover from this, because we don't know how many of the
 		// following descriptor chains belong to this packet.
@@ -298,72 +292,44 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
 	}
 
 	//we have the header now: what do we need to do?
-	if int(pkt.Header.NumBuffers) > len(chains) {
-		return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
+	if int(pkt.header.NumBuffers) > 1 {
+		return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1)
 	}
-	if int(pkt.Header.NumBuffers) != 1 {
-		return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
+	if int(pkt.header.NumBuffers) != 1 {
+		return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers))
 	}
-	if chains[0].Length > 16000 {
+	if chain.Length > 16000 {
 		//todo!
-		return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
+		return 1, fmt.Errorf("too big packet length: %d", chain.Length)
 	}
 
 	//shift the buffer out of out:
-	pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
-	pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
+	pkt.payload = buf[virtio.NetHdrSize:chain.Length]
+	pkt.Chains = append(pkt.Chains, idx)
 	return 1, nil
-
-	//cursor := n - virtio.NetHdrSize
-	//
-	//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
-	//	pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
-	//	return 1, nil
-	//}
-	//
-	//i := 1
-	//// we used chain 0 already
-	//for i = 1; i < len(chains); i++ {
-	//	n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
-	//	if err != nil {
-	//		// When this fails we may miss to free some descriptor chains. We
-	//		// could try to mitigate this by deferring the freeing somehow, but
-	//		// it's not worth the hassle. When this method fails, the queue will
-	//		// be in a broken state anyway.
-	//		return i, fmt.Errorf("get descriptor chain: %w", err)
-	//	}
-	//	cursor += n
-	//}
-	////todo this has to be wrong
-	//pkt.Payload = pkt.Payload[:cursor]
-	//return i, nil
 }
 
-func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
-	//todo optimize?
-	var chains []virtqueue.UsedElement
-	var err error
-
-	chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
-	if err != nil {
-		return 0, err
-	}
-	if len(chains) == 0 {
-		return 0, nil
-	}
-
-	numPackets := 0
-	chainsIdx := 0
-	for numPackets = 0; chainsIdx < len(chains); numPackets++ {
-		if numPackets >= len(out) {
-			return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
-		}
-		numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
-		if err != nil {
-			return 0, err
-		}
-		chainsIdx += numChains
-	}
-
-	return numPackets, nil
-}
+type VirtIOPacket struct {
+	payload []byte
+	header  virtio.NetHdr
+	Chains  []uint16
+}
+
+func NewVIO() *VirtIOPacket {
+	out := new(VirtIOPacket)
+	out.payload = nil
+	out.Chains = make([]uint16, 0, 8)
+	return out
+}
+
+func (v *VirtIOPacket) Reset() {
+	v.payload = nil
+	v.Chains = v.Chains[:0]
+}
+
+func (v *VirtIOPacket) GetPayload() []byte {
+	return v.payload
+}
+
+func (v *VirtIOPacket) SetPayload(x []byte) {
+	v.payload = x //todo?
+}
@@ -10,10 +10,6 @@ import (
 )
 
 var (
-	// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
-	// no buffers, which is not allowed.
-	ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
-
 	// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
 	// exhausted, meaning that the queue is full.
 	ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
@@ -272,59 +268,6 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
 	return head, nil
 }
 
 // TODO: Implement a zero-copy variant of createDescriptorChain?
 
-// getDescriptorChain returns the device-readable buffers (out buffers) and
-// device-writable buffers (in buffers) of the descriptor chain that starts with
-// the given head index. The descriptor chain must have been created using
-// [createDescriptorChain] and must not have been freed yet (meaning that the
-// head index must not be contained in the free chain).
-//
-// Be careful to only access the returned buffer slices when the device has not
-// yet or is no longer using them. They must not be accessed after
-// [freeDescriptorChain] has been called.
-func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
-	if int(head) > len(dt.descriptors) {
-		return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
-	}
-
-	// Iterate over the chain. The iteration is limited to the queue size to
-	// avoid ending up in an endless loop when things go very wrong.
-	next := head
-	for range len(dt.descriptors) {
-		if next == dt.freeHeadIndex {
-			return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
-		}
-
-		desc := &dt.descriptors[next]
-
-		// The descriptor address points to memory not managed by Go, so this
-		// conversion is safe. See https://github.com/golang/go/issues/58625
-		//goland:noinspection GoVetUnsafePointer
-		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
-
-		if desc.flags&descriptorFlagWritable == 0 {
-			outBuffers = append(outBuffers, bs)
-		} else {
-			inBuffers = append(inBuffers, bs)
-		}
-
-		// Is this the tail of the chain?
-		if desc.flags&descriptorFlagHasNext == 0 {
-			break
-		}
-
-		// Detect loops.
-		if desc.next == head {
-			return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
-		}
-
-		next = desc.next
-	}
-
-	return
-}
-
+func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
+	if int(head) > len(dt.descriptors) {
+		return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
@@ -339,121 +282,6 @@ func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
 	return bs, nil
 }
 
-func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
-	if int(head) > len(dt.descriptors) {
-		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
-	}
-
-	// Iterate over the chain. The iteration is limited to the queue size to
-	// avoid ending up in an endless loop when things go very wrong.
-	next := head
-	for range len(dt.descriptors) {
-		if next == dt.freeHeadIndex {
-			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
-		}
-
-		desc := &dt.descriptors[next]
-
-		// The descriptor address points to memory not managed by Go, so this
-		// conversion is safe. See https://github.com/golang/go/issues/58625
-		//goland:noinspection GoVetUnsafePointer
-		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
-
-		if desc.flags&descriptorFlagWritable == 0 {
-			return fmt.Errorf("there should not be an outbuffer in %d", head)
-		} else {
-			*inBuffers = append(*inBuffers, bs)
-		}
-
-		// Is this the tail of the chain?
-		if desc.flags&descriptorFlagHasNext == 0 {
-			break
-		}
-
-		// Detect loops.
-		if desc.next == head {
-			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
-		}
-
-		next = desc.next
-	}
-
-	return nil
-}
-
-// freeDescriptorChain can be used to free a descriptor chain when it is no
-// longer in use. The descriptor chain that starts with the given index will be
-// put back into the free chain, so the descriptors can be used for later calls
-// of [createDescriptorChain].
-// The descriptor chain must have been created using [createDescriptorChain] and
-// must not have been freed yet (meaning that the head index must not be
-// contained in the free chain).
-func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
-	if int(head) > len(dt.descriptors) {
-		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
-	}
-
-	// Iterate over the chain. The iteration is limited to the queue size to
-	// avoid ending up in an endless loop when things go very wrong.
-	next := head
-	var tailDesc *Descriptor
-	var chainLen uint16
-	for range len(dt.descriptors) {
-		if next == dt.freeHeadIndex {
-			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
-		}
-
-		desc := &dt.descriptors[next]
-		chainLen++
-
-		// Set the length of all unused descriptors back to zero.
-		desc.length = 0
-
-		// Unset all flags except the next flag.
-		desc.flags &= descriptorFlagHasNext
-
-		// Is this the tail of the chain?
-		if desc.flags&descriptorFlagHasNext == 0 {
-			tailDesc = desc
-			break
-		}
-
-		// Detect loops.
-		if desc.next == head {
-			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
-		}
-
-		next = desc.next
-	}
-	if tailDesc == nil {
-		// A descriptor chain longer than the queue size but without loops
-		// should be impossible.
-		panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
-	}
-
-	// The tail descriptor does not have the next flag set, but when it comes
-	// back into the free chain, it should have.
-	tailDesc.flags = descriptorFlagHasNext
-
-	if dt.freeHeadIndex == noFreeHead {
-		// The whole free chain was used up, so we turn this returned descriptor
-		// chain into the new free chain by completing the circle and using its
-		// head.
-		tailDesc.next = head
-		dt.freeHeadIndex = head
-	} else {
-		// Attach the returned chain at the beginning of the free chain but
-		// right after the free chain head.
-		freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
-		tailDesc.next = freeHeadDesc.next
-		freeHeadDesc.next = head
-	}
-
-	dt.freeNum += chainLen
-
-	return nil
-}
-
 // checkUnusedDescriptorLength asserts that the length of an unused descriptor
 // is zero, as it should be.
 // This is not a requirement by the virtio spec but rather a thing we do to
@@ -128,8 +128,7 @@ func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
 		return nil, err
 	}
 
-	// Consume used buffer notifications in the background.
-	sq.stop = sq.startConsumeUsedRing()
+	sq.stop = sq.kickSelfToExit()
 
 	return &sq, nil
 }
@@ -169,9 +168,7 @@ func (sq *SplitQueue) CallEventFD() int {
 	return sq.callEventFD.FD()
 }
 
-// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
-// A function is returned that can be used to gracefully cancel it. todo rename
-func (sq *SplitQueue) startConsumeUsedRing() func() error {
+func (sq *SplitQueue) kickSelfToExit() func() error {
 	return func() error {
 
 		// The goroutine blocks until it receives a signal on the event file
@@ -185,7 +182,15 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
 	}
 }
 
-func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
+func (sq *SplitQueue) TakeSingleIndex(ctx context.Context) (uint16, error) {
+	element, err := sq.TakeSingle(ctx)
+	if err != nil {
+		return 0xffff, err
+	}
+	return element.GetHead(), nil
+}
+
+func (sq *SplitQueue) TakeSingle(ctx context.Context) (UsedElement, error) {
 	var n int
 	var err error
 	for ctx.Err() == nil {
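TakeSingle now surfaces the whole UsedElement (head index plus the length the device wrote), while TakeSingleIndex keeps the old index-only contract for callers such as GetPacketForTx. A caller that wants both fields, sketched (fragment, not from this commit):

elem, err := sq.TakeSingle(ctx)
if err != nil {
	return err
}
head := elem.GetHead() // descriptor chain head index
n := elem.Length       // bytes the device wrote into that chain
_, _ = head, n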
@@ -195,7 +200,7 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
 		}
 		// Wait for a signal from the device.
 		if n, err = sq.epoll.Block(); err != nil {
-			return 0, fmt.Errorf("wait: %w", err)
+			return UsedElement{}, fmt.Errorf("wait: %w", err)
 		}
 
 		if n > 0 {
@@ -208,7 +213,31 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
 		}
 	}
-	return 0, ctx.Err()
+	return UsedElement{}, ctx.Err()
 }
 
+func (sq *SplitQueue) TakeSingleNoBlock() (UsedElement, bool) {
+	return sq.usedRing.takeOne()
+}
+
+func (sq *SplitQueue) WaitForUsedElements(ctx context.Context) error {
+	if sq.usedRing.availableToTake() != 0 {
+		return nil
+	}
+	for ctx.Err() == nil {
+		// Wait for a signal from the device.
+		n, err := sq.epoll.Block()
+		if err != nil {
+			return fmt.Errorf("wait: %w", err)
+		}
+		if n > 0 {
+			_ = sq.epoll.Clear()
+			if sq.usedRing.availableToTake() != 0 {
+				return nil
+			}
+		}
+	}
+	return ctx.Err()
+}
+
 func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
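WaitForUsedElements plus TakeSingleNoBlock split the old blocking consume into "sleep until the device reports something" and "drain without sleeping", which is what lets tun.ReadMany fill a whole batch per wakeup. The pattern, sketched (process is hypothetical):

if err := sq.WaitForUsedElements(ctx); err != nil {
	return err // blocks at most once per batch
}
for {
	elem, ok := sq.TakeSingleNoBlock()
	if !ok {
		break // ring drained; go back to waiting
	}
	process(elem)
}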
@@ -235,7 +264,7 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
 			return nil, fmt.Errorf("wait: %w", err)
 		}
 		if n > 0 {
-			_ = sq.epoll.Clear() //???
+			_ = sq.epoll.Clear()
 			stillNeedToTake, out = sq.usedRing.take(maxToTake)
 			sq.more = stillNeedToTake
 			return out, nil
@@ -296,16 +325,14 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
 	sq.availableRing.offerSingle(head)
 
 	// Notify the device to make it process the updated available ring.
-	if err := sq.kickEventFD.Kick(); err != nil {
+	if err = sq.kickEventFD.Kick(); err != nil {
 		return head, fmt.Errorf("notify device: %w", err)
 	}
 
 	return head, nil
 }
 
-// GetDescriptorChain returns the device-readable buffers (out buffers) and
-// device-writable buffers (in buffers) of the descriptor chain with the given
-// head index.
+// GetDescriptorItem returns the buffer of a given index
 // The head index must be one that was returned by a previous call to
 // [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
 // freed yet.
@@ -313,37 +340,11 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
 // Be careful to only access the returned buffer slices when the device is no
 // longer using them. They must not be accessed after
 // [SplitQueue.FreeDescriptorChain] has been called.
-func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
-	return sq.descriptorTable.getDescriptorChain(head)
-}
-
+func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
+	sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
+	return sq.descriptorTable.getDescriptorItem(head)
+}
 
-func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
-	return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
-}
-
-// FreeDescriptorChain frees the descriptor chain with the given head index.
-// The head index must be one that was returned by a previous call to
-// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
-// freed yet.
-//
-// This creates new room in the queue which can be used by following
-// [SplitQueue.OfferDescriptorChain] calls.
-// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
-// are waiting for free room in the queue, they may become unblocked by this.
-func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
-	//not called under lock
-	if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
-		return fmt.Errorf("free: %w", err)
-	}
-
-	return nil
-}
-
 func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
 	//not called under lock
 	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
@@ -84,17 +84,11 @@ func (r *UsedRing) Address() uintptr {
 	return uintptr(unsafe.Pointer(r.flags))
 }
 
-// take returns all new [UsedElement]s that the device put into the ring and
-// that weren't already returned by a previous call to this method.
-// had a lock, I removed it
-func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
-	//r.mu.Lock()
-	//defer r.mu.Unlock()
-
+func (r *UsedRing) availableToTake() int {
 	ringIndex := *r.ringIndex
 	if ringIndex == r.lastIndex {
 		// Nothing new.
-		return 0, nil
+		return 0
 	}
 
 	// Calculate the number new used elements that we can read from the ring.
@@ -103,6 +97,16 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
 	if count < 0 {
 		count += 0xffff
 	}
+	return count
+}
+
+// take returns all new [UsedElement]s that the device put into the ring and
+// that weren't already returned by a previous call to this method.
+func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
+	count := r.availableToTake()
+	if count == 0 {
+		return 0, nil
+	}
 
 	stillNeedToTake := 0
 
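availableToTake depends on the ring indices being free-running uint16 counters: the device keeps incrementing *r.ringIndex past 0xffff, and unsigned subtraction still yields the true distance. A standalone worked example of the arithmetic (not from this commit):

var last, ring uint16 = 0xfffe, 2
count := int(ring - last) // uint16 math wraps: 0x0002 - 0xfffe = 4
// count == 4: four new used elements even though ring < last numerically
// (the int conversion happens after the uint16 wrap, so count is never negative)
_ = count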
@@ -128,21 +132,13 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
 	return stillNeedToTake, elems
 }
 
-func (r *UsedRing) takeOne() (uint16, bool) {
-	//r.mu.Lock()
-	//defer r.mu.Unlock()
-
-	ringIndex := *r.ringIndex
-	if ringIndex == r.lastIndex {
-		// Nothing new.
-		return 0xffff, false
-	}
-
-	// Calculate the number new used elements that we can read from the ring.
-	// The ring index may wrap, so special handling for that case is needed.
-	count := int(ringIndex - r.lastIndex)
-	if count < 0 {
-		count += 0xffff
+func (r *UsedRing) takeOne() (UsedElement, bool) {
+	count := r.availableToTake()
+	if count == 0 {
+		return UsedElement{}, false
 	}
 
 	// The number of new elements can never exceed the queue size.
@@ -150,11 +146,7 @@ func (r *UsedRing) takeOne() (uint16, bool) {
 		panic("used ring contains more new elements than the ring is long")
 	}
 
-	if count == 0 {
-		return 0xffff, false
-	}
-
-	out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
+	out := r.ring[r.lastIndex%uint16(len(r.ring))]
 	r.lastIndex++
 
 	return out, true
@@ -14,7 +14,7 @@ import (
 
 const Size = 0xffff
 
-type Packet struct {
+type UDPPacket struct {
 	Payload []byte
 	Control []byte
 	Name    []byte
@@ -25,8 +25,8 @@ type Packet struct {
 	isV4 bool
 }
 
-func New(isV4 bool) *Packet {
-	return &Packet{
+func New(isV4 bool) *UDPPacket {
+	return &UDPPacket{
 		Payload: make([]byte, Size),
 		Control: make([]byte, unix.CmsgSpace(2)),
 		Name:    make([]byte, unix.SizeofSockaddrInet6),
@@ -34,7 +34,7 @@ func New(isV4 bool) *Packet {
 	}
 }
 
-func (p *Packet) AddrPort() netip.AddrPort {
+func (p *UDPPacket) AddrPort() netip.AddrPort {
 	var ip netip.Addr
 	// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
 	if p.isV4 {
@@ -45,7 +45,7 @@ func (p *Packet) AddrPort() netip.AddrPort {
 	return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
 }
 
-func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
+func (p *UDPPacket) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
 	//todo no chance this works on windows?
 	if p.isV4 {
 		if !addr.Addr().Is4() {
@@ -69,7 +69,7 @@ func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error)
 	return uint32(size), nil
 }
 
-func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
+func (p *UDPPacket) SetAddrPort(addr netip.AddrPort) error {
 	nl, err := p.encodeSockaddr(p.Name, addr)
 	if err != nil {
 		return err
@@ -78,7 +78,7 @@ func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
 	return nil
 }
 
-func (p *Packet) updateCtrl(ctrlLen int) {
+func (p *UDPPacket) updateCtrl(ctrlLen int) {
 	p.SegSize = len(p.Payload)
 	p.wasSegmented = false
 	if ctrlLen == 0 {
@@ -101,12 +101,12 @@ func (p *Packet) updateCtrl(ctrlLen int) {
 	}
 }
 
-// Update sets a Packet into "just received, not processed" state
-func (p *Packet) Update(ctrlLen int) {
+// Update sets a UDPPacket into "just received, not processed" state
+func (p *UDPPacket) Update(ctrlLen int) {
 	p.updateCtrl(ctrlLen)
 }
 
-func (p *Packet) SetSegSizeForTX() {
+func (p *UDPPacket) SetSegSizeForTX() {
 	p.SegSize = len(p.Payload)
 	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
 	hdr.Level = unix.SOL_UDP
@@ -115,7 +115,7 @@ func (p *Packet) SetSegSizeForTX() {
 	binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
 }
 
-func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
+func (p *UDPPacket) CompatibleForSegmentationWith(otherP *UDPPacket, currentTotalSize int) bool {
 	//same dest
 	if !slices.Equal(p.Name, otherP.Name) {
 		return false
@@ -134,7 +134,7 @@ func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize
 	return true
 }
 
-func (p *Packet) Segments() iter.Seq[[]byte] {
+func (p *UDPPacket) Segments() iter.Seq[[]byte] {
 	return func(yield func([]byte) bool) {
 		//cursor := 0
 		for offset := 0; offset < len(p.Payload); offset += p.SegSize {
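Segments exposes a GSO-coalesced payload as a Go 1.23 range-over-func iterator, yielding one SegSize-sized slice per wire packet (the final slice is presumably shorter when the payload is not an exact multiple). Hypothetical usage:

for seg := range pkt.Segments() {
	deliver(seg) // hypothetical per-segment consumer; seg aliases pkt.Payload
}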
@@ -1,26 +0,0 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"github.com/slackhq/nebula/util/virtio"
|
||||
)
|
||||
|
||||
type VirtIOPacket struct {
|
||||
Payload []byte
|
||||
Header virtio.NetHdr
|
||||
Chains []uint16
|
||||
ChainRefs [][]byte
|
||||
}
|
||||
|
||||
func NewVIO() *VirtIOPacket {
|
||||
out := new(VirtIOPacket)
|
||||
out.Payload = nil
|
||||
out.ChainRefs = make([][]byte, 0, 4)
|
||||
out.Chains = make([]uint16, 0, 8)
|
||||
return out
|
||||
}
|
||||
|
||||
func (v *VirtIOPacket) Reset() {
|
||||
v.Payload = nil
|
||||
v.ChainRefs = v.ChainRefs[:0]
|
||||
v.Chains = v.Chains[:0]
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
 
 const MTU = 9001
 
 type EncReader func(
-	[]*packet.Packet,
+	[]*packet.UDPPacket,
 )
 
 type Conn interface {
@@ -19,8 +19,8 @@ type Conn interface {
 	ListenOut(r EncReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
-	Prep(pkt *packet.Packet, addr netip.AddrPort) error
-	WriteBatch(pkt []*packet.Packet) (int, error)
+	Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error
+	WriteBatch(pkt []*packet.UDPPacket) (int, error)
 	SupportsMultipleReaders() bool
 	Close() error
 }
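After the rename, the batched send path is unchanged in shape: Prep encodes the destination sockaddr into the packet's preallocated Name buffer, and WriteBatch flushes the slice, coalescing adjacent packets with GSO where CompatibleForSegmentationWith allows. A sketch of a sender (dests and data are hypothetical):

pkts := make([]*packet.UDPPacket, 0, len(dests))
for _, dst := range dests {
	p := packet.New(dst.Addr().Is4())
	p.Payload = p.Payload[:copy(p.Payload, data)] // fill and truncate the buffer
	if err := conn.Prep(p, dst); err != nil {
		return err
	}
	pkts = append(pkts, p)
}
sent, err := conn.WriteBatch(pkts)
_ = sent // number of packets written
return err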
@@ -215,7 +215,7 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
 	return u.writeTo6(b, ip)
 }
 
-func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
+func (u *StdConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
 	//todo move this into pkt
 	nl, err := u.encodeSockaddr(pkt.Name, addr)
 	if err != nil {
@@ -226,7 +226,7 @@ func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
 	return nil
 }
 
-func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
+func (u *StdConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
 	if len(pkts) == 0 {
 		return 0, nil
 	}
@@ -235,7 +235,7 @@ func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
 	//u.iovs = u.iovs[:0]
 
 	sent := 0
-	var mostRecentPkt *packet.Packet
+	var mostRecentPkt *packet.UDPPacket
 	mostRecentPktSize := 0
 	//segmenting := false
 	idx := 0
@@ -52,9 +52,9 @@ func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint64(l)
 }
 
-func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
+func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.UDPPacket) {
 	msgs := make([]rawMessage, n)
-	packets := make([]*packet.Packet, n)
+	packets := make([]*packet.UDPPacket, n)
 
 	for i := range msgs {
 		packets[i] = packet.New(isV4)
@@ -41,7 +41,7 @@ type TesterConn struct {
 	l *logrus.Logger
 }
 
-func (u *TesterConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
+func (u *TesterConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
 	pkt.ReadyToSend = true
 	return pkt.SetAddrPort(addr)
 }
@@ -96,7 +96,7 @@ func (u *TesterConn) Get(block bool) *Packet {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (u *TesterConn) WriteBatch(pkts []*packet.Packet) (int, error) {
+func (u *TesterConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
 	for _, pkt := range pkts {
 		if !pkt.ReadyToSend {
 			continue
@@ -141,7 +141,7 @@ func (u *TesterConn) ListenOut(r EncReader) {
 		if err != nil {
 			panic(err)
 		}
-		y := []*packet.Packet{x}
+		y := []*packet.UDPPacket{x}
 		r(y)
 	}
 }