mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 00:44:25 +01:00
this is awful, but also it's about 20% better
This commit is contained in:
@@ -1,17 +1,16 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
io.ReadWriteCloser
|
||||
TunDev
|
||||
Activate() error
|
||||
Networks() []netip.Prefix
|
||||
Name() string
|
||||
RoutesFor(netip.Addr) routing.Gateways
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
NewMultiQueueReader() (TunDev, error)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
@@ -12,6 +13,11 @@ import (
|
||||
|
||||
const DefaultMTU = 1300
|
||||
|
||||
type TunDev interface {
|
||||
io.ReadWriteCloser
|
||||
WriteMany([][]byte) (int, error)
|
||||
}
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
|
||||
|
||||
@@ -105,7 +105,19 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
|
||||
out := 0
|
||||
for i := range b {
|
||||
x, err := t.Write(b[i])
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out += x
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (TunDev, error) {
|
||||
//fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
@@ -741,3 +741,24 @@ func (t *tun) Write(b []byte) (int, error) {
|
||||
}
|
||||
return maximum, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteMany(b [][]byte) (int, error) {
|
||||
maximum := len(b) //we are RXing
|
||||
|
||||
hdr := virtio.NetHdr{ //todo
|
||||
Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
|
||||
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
|
||||
HdrLen: 0,
|
||||
GSOSize: 0,
|
||||
CsumStart: 0,
|
||||
CsumOffset: 0,
|
||||
NumBuffers: 0,
|
||||
}
|
||||
|
||||
err := t.vdev.TransmitPackets(hdr, b)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Transmitting packet")
|
||||
return 0, err
|
||||
}
|
||||
return maximum, nil
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||
}
|
||||
|
||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -65,3 +65,15 @@ func (d *UserDevice) Close() error {
|
||||
d.outboundWriter.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
|
||||
out := 0
|
||||
for i := range b {
|
||||
x, err := d.Write(b[i])
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out += x
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -311,6 +311,33 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
|
||||
// Prepend the packet with its virtio-net header.
|
||||
vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
|
||||
if err := vnethdr.Encode(vnethdrBuf); err != nil {
|
||||
return fmt.Errorf("encode vnethdr: %w", err)
|
||||
}
|
||||
vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
|
||||
vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
|
||||
|
||||
chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("offer descriptor chain: %w", err)
|
||||
}
|
||||
|
||||
//todo blocking here suxxxx
|
||||
// Wait for the packet to have been transmitted.
|
||||
for i := range chainIndexes {
|
||||
<-dev.transmitted[chainIndexes[i]]
|
||||
|
||||
if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
|
||||
return fmt.Errorf("free descriptor chain: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReceivePacket reads the next available packet from the receive queue of this
|
||||
// device and returns its [virtio.NetHdr] and packet data separately.
|
||||
//
|
||||
|
||||
@@ -345,6 +345,66 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
|
||||
return head, nil
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
|
||||
sq.ensureInitialized()
|
||||
|
||||
// TODO change this
|
||||
// Each descriptor can only hold a whole memory page, so split large out
|
||||
// buffers into multiple smaller ones.
|
||||
outBuffers = splitBuffers(outBuffers, sq.pageSize)
|
||||
|
||||
// Synchronize the offering of descriptor chains. While the descriptor table
|
||||
// and available ring are synchronized on their own as well, this does not
|
||||
// protect us from interleaved calls which could cause reordering.
|
||||
// By locking here, we can ensure that all descriptor chains are made
|
||||
// available to the device in the same order as this method was called.
|
||||
sq.offerMutex.Lock()
|
||||
defer sq.offerMutex.Unlock()
|
||||
|
||||
chains := make([]uint16, len(outBuffers))
|
||||
|
||||
// Create a descriptor chain for the given buffers.
|
||||
var (
|
||||
head uint16
|
||||
err error
|
||||
)
|
||||
for i := range outBuffers {
|
||||
for {
|
||||
bufs := [][]byte{prepend, outBuffers[i]}
|
||||
head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// I don't wanna use errors.Is, it's slow
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if err == ErrNotEnoughFreeDescriptors {
|
||||
if waitFree {
|
||||
// Wait for more free descriptors to be put back into the queue.
|
||||
// If the number of free descriptors is still not sufficient, we'll
|
||||
// land here again.
|
||||
sq.blockForMoreDescriptors()
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("create descriptor chain: %w", err)
|
||||
}
|
||||
chains[i] = head
|
||||
}
|
||||
|
||||
// Make the descriptor chain available to the device.
|
||||
sq.availableRing.offer(chains)
|
||||
|
||||
// Notify the device to make it process the updated available ring.
|
||||
if err := sq.kickEventFD.Kick(); err != nil {
|
||||
return chains, fmt.Errorf("notify device: %w", err)
|
||||
}
|
||||
|
||||
return chains, nil
|
||||
}
|
||||
|
||||
// GetDescriptorChain returns the device-readable buffers (out buffers) and
|
||||
// device-writable buffers (in buffers) of the descriptor chain with the given
|
||||
// head index.
|
||||
|
||||
Reference in New Issue
Block a user