This is awful, but it is also about 20% better.

This commit is contained in:
JackDoan
2025-11-08 14:57:33 -06:00
parent 1f043f84f3
commit 42591c2042
11 changed files with 427 additions and 19 deletions

View File

@@ -1,17 +1,16 @@
package overlay
import (
"io"
"net/netip"
"github.com/slackhq/nebula/routing"
)
// Device is the full overlay tun device abstraction: a TunDev that can also
// be activated, report its networks/name, and resolve routes.
// NOTE(review): this span is a rendered diff; both the removed and the added
// NewMultiQueueReader signatures appear below — only the TunDev-returning
// variant exists after this commit.
type Device interface {
io.ReadWriteCloser
TunDev
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (io.ReadWriteCloser, error)
NewMultiQueueReader() (TunDev, error)
}

View File

@@ -2,6 +2,7 @@ package overlay
import (
"fmt"
"io"
"net"
"net/netip"
@@ -12,6 +13,11 @@ import (
const DefaultMTU = 1300
// TunDev is the minimal packet I/O surface shared by all tun implementations,
// including multi-queue readers.
type TunDev interface {
io.ReadWriteCloser
// WriteMany writes a batch of packets in one call, returning the first
// error encountered. NOTE(review): implementations disagree on the int —
// disabledTun returns total bytes while tun returns a packet count; the
// intended contract should be pinned down.
WriteMany([][]byte) (int, error)
}
// TODO: We may be able to remove routines
// DeviceFactory builds a Device from config and logger for the given VPN
// networks; routines is the reader-routine count (see TODO above — it may
// become removable).
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)

View File

@@ -105,7 +105,19 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// WriteMany writes each buffer in b as a single packet via Write, returning
// the total number of bytes written and the first error encountered (with
// the byte count accumulated up to that point).
func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
	total := 0
	for _, pkt := range b {
		n, err := t.Write(pkt)
		if err != nil {
			return total, err
		}
		total += n
	}
	return total, nil
}
// NewMultiQueueReader returns the device itself: a disabled tun has no real
// queues, so every "queue" shares this same no-op reader/writer.
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
return t, nil
}

View File

@@ -257,7 +257,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) NewMultiQueueReader() (TunDev, error) {
//fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
//if err != nil {
// return nil, err
@@ -741,3 +741,24 @@ func (t *tun) Write(b []byte) (int, error) {
}
return maximum, nil
}
// WriteMany transmits a batch of packets through the virtio device under one
// shared vnet header. On success it returns len(b) — a packet count, not a
// byte total. NOTE(review): disabledTun.WriteMany returns summed bytes
// instead; confirm which semantics TunDev.WriteMany is meant to have.
func (t *tun) WriteMany(b [][]byte) (int, error) {
maximum := len(b) //we are RXing
// Single header for the whole batch: checksum marked already valid, no GSO.
hdr := virtio.NetHdr{ //todo
Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
HdrLen: 0,
GSOSize: 0,
CsumStart: 0,
CsumOffset: 0,
NumBuffers: 0,
}
err := t.vdev.TransmitPackets(hdr, b)
if err != nil {
// NOTE(review): error is both logged and returned — callers likely log it
// again; handle at exactly one layer.
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return maximum, nil
}

View File

@@ -46,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
// NewMultiQueueReader returns the device itself; UserDevice has no separate
// queues, so all readers share the same underlying pipe.
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
return d, nil
}
@@ -65,3 +65,15 @@ func (d *UserDevice) Close() error {
d.outboundWriter.Close()
return nil
}
// WriteMany sends each buffer in b through Write one packet at a time.
// It returns the cumulative byte count written so far together with the
// first error, or the full total with a nil error on success.
func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
	var written int
	for _, buf := range b {
		n, err := d.Write(buf)
		if err != nil {
			return written, err
		}
		written += n
	}
	return written, nil
}

View File

@@ -311,6 +311,33 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
return nil
}
// TransmitPackets queues a batch of packets on the transmit queue, all
// sharing a single encoded virtio-net header, then blocks until each packet
// has been transmitted and its descriptor chain freed.
func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
// Prepend the packet with its virtio-net header.
// The extra 14 bytes look like a synthetic Ethernet frame header (only the
// trailing 2-byte EtherType is populated below) — TODO confirm why it is
// required here (see the author's own "todo WHY").
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
if err := vnethdr.Encode(vnethdrBuf); err != nil {
return fmt.Errorf("encode vnethdr: %w", err)
}
// 0x86DD is the IPv6 EtherType. NOTE(review): IPv4 (0x0800) is never set,
// so this path presumably assumes IPv6-only frames — verify against callers.
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
if err != nil {
return fmt.Errorf("offer descriptor chain: %w", err)
}
//todo blocking here suxxxx
// Wait for the packet to have been transmitted.
// Each chain index has a dedicated channel that is signalled on completion;
// chains are freed in offer order, which serializes the whole batch.
for i := range chainIndexes {
<-dev.transmitted[chainIndexes[i]]
if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
}
return nil
}
// ReceivePacket reads the next available packet from the receive queue of this
// device and returns its [virtio.NetHdr] and packet data separately.
//

View File

@@ -345,6 +345,66 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
return head, nil
}
// OfferOutDescriptorChains builds one two-buffer descriptor chain
// (prepend + packet) per out buffer, offers them all to the device in call
// order, kicks the device, and returns the head index of every chain.
// When waitFree is true it blocks until enough free descriptors exist;
// otherwise it fails fast with ErrNotEnoughFreeDescriptors.
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
sq.ensureInitialized()
// TODO change this
// Each descriptor can only hold a whole memory page, so split large out
// buffers into multiple smaller ones.
outBuffers = splitBuffers(outBuffers, sq.pageSize)
// Synchronize the offering of descriptor chains. While the descriptor table
// and available ring are synchronized on their own as well, this does not
// protect us from interleaved calls which could cause reordering.
// By locking here, we can ensure that all descriptor chains are made
// available to the device in the same order as this method was called.
sq.offerMutex.Lock()
defer sq.offerMutex.Unlock()
chains := make([]uint16, len(outBuffers))
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for i := range outBuffers {
// Retry loop: only ErrNotEnoughFreeDescriptors is retryable (and only
// when waitFree is set); any other error aborts the whole batch.
for {
// NOTE(review): the same prepend slice is shared by every chain — safe
// only if the device treats it as read-only; confirm.
bufs := [][]byte{prepend, outBuffers[i]}
head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
if err == nil {
break
}
// I don't wanna use errors.Is, it's slow
// NOTE(review): direct comparison misses wrapped errors — this is only
// correct while createDescriptorChain returns the sentinel unwrapped.
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
if waitFree {
// Wait for more free descriptors to be put back into the queue.
// If the number of free descriptors is still not sufficient, we'll
// land here again.
sq.blockForMoreDescriptors()
continue
} else {
return nil, err
}
}
return nil, fmt.Errorf("create descriptor chain: %w", err)
}
chains[i] = head
}
// Make the descriptor chain available to the device.
sq.availableRing.offer(chains)
// Notify the device to make it process the updated available ring.
// NOTE(review): on kick failure the chains have already been offered, so
// the indexes are returned alongside the error for the caller to reclaim.
if err := sq.kickEventFD.Kick(); err != nil {
return chains, fmt.Errorf("notify device: %w", err)
}
return chains, nil
}
// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.