mirror of
https://github.com/slackhq/nebula.git
synced 2025-12-16 20:08:27 +01:00
vhost
This commit is contained in:
@@ -5,7 +5,6 @@ package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -17,15 +16,19 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/vhostnet"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/util/virtio"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
file *os.File
|
||||
fd int
|
||||
vdev []*vhostnet.Device
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
@@ -40,7 +43,8 @@ type tun struct {
|
||||
useSystemRoutes bool
|
||||
useSystemRoutesBufferSize int
|
||||
|
||||
l *logrus.Logger
|
||||
isV6 bool
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
@@ -102,7 +106,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI)
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
@@ -112,20 +116,47 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
if err = unix.SetNonblock(fd, true); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set vnethdr size: %w", err)
|
||||
}
|
||||
|
||||
flags := 0
|
||||
//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
|
||||
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set offloads: %w", err)
|
||||
}
|
||||
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.fd = fd
|
||||
t.Device = name
|
||||
|
||||
vdev, err := vhostnet.NewDevice(
|
||||
vhostnet.WithBackendFD(fd),
|
||||
vhostnet.WithQueueSize(8192), //todo config
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.vdev = []*vhostnet.Device{vdev}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
file: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
@@ -133,6 +164,9 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
l: l,
|
||||
}
|
||||
if len(vpnNetworks) != 0 {
|
||||
t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP?
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
@@ -220,7 +254,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (TunDev, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -233,9 +267,17 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
vdev, err := vhostnet.NewDevice(
|
||||
vhostnet.WithBackendFD(fd),
|
||||
vhostnet.WithQueueSize(8192), //todo config
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return file, nil
|
||||
t.vdev = append(t.vdev, vdev)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
@@ -243,29 +285,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
var nn int
|
||||
maximum := len(b)
|
||||
|
||||
for {
|
||||
n, err := unix.Write(t.fd, b[nn:maximum])
|
||||
if n > 0 {
|
||||
nn += n
|
||||
}
|
||||
if nn == len(b) {
|
||||
return nn, err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nn, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nn, io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
@@ -689,8 +708,14 @@ func (t *tun) Close() error {
|
||||
close(t.routeChan)
|
||||
}
|
||||
|
||||
if t.ReadWriteCloser != nil {
|
||||
_ = t.ReadWriteCloser.Close()
|
||||
for _, v := range t.vdev {
|
||||
if v != nil {
|
||||
_ = v.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if t.file != nil {
|
||||
_ = t.file.Close()
|
||||
}
|
||||
|
||||
if t.ioctlFd > 0 {
|
||||
@@ -699,3 +724,65 @@ func (t *tun) Close() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
|
||||
n, err := t.vdev[q].ReceivePackets(p) //we are TXing
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
maximum := len(b) //we are RXing
|
||||
|
||||
//todo garbagey
|
||||
out := packet.NewOut()
|
||||
x, err := t.AllocSeg(out, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
copy(out.SegmentPayloads[x], b)
|
||||
err = t.vdev[0].TransmitPacket(out, true)
|
||||
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Transmitting packet")
|
||||
return 0, err
|
||||
}
|
||||
return maximum, nil
|
||||
}
|
||||
|
||||
func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
||||
idx, buf, err := t.vdev[q].GetPacketForTx()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
x := pkt.UseSegment(idx, buf, t.isV6)
|
||||
return x, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
|
||||
if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
|
||||
t.l.WithError(err).Error("Transmitting packet")
|
||||
return 0, err
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
||||
maximum := len(x) //we are RXing
|
||||
if maximum == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
err := t.vdev[q].TransmitPackets(x)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Transmitting packet")
|
||||
return 0, err
|
||||
}
|
||||
return maximum, nil
|
||||
}
|
||||
|
||||
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
||||
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user