diff --git a/main.go b/main.go index eb296fb..4aa7444 100644 --- a/main.go +++ b/main.go @@ -162,9 +162,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg listenHost = ips[0].Unmap() } - for i := 0; i < routines; i++ { - l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) - udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) + useWG := c.GetBool("listen.use_wireguard_stack", false) + var mkListener func(*logrus.Logger, netip.Addr, int, bool, int) (udp.Conn, error) + if useWG { + mkListener = udp.NewWireguardListener + } else { + mkListener = udp.NewListener + } + + for i := 0; i < routines; i++ { + l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 44d8746..64b400f 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -19,6 +19,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" + wgtun "github.com/slackhq/nebula/wgstack/tun" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) @@ -33,6 +34,7 @@ type tun struct { TXQueueLen int deviceIndex int ioctlFd uintptr + wgDevice wgtun.Device Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -68,7 +70,8 @@ type ifreqQLEN struct { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", false)) + t, err := newTunGeneric(c, l, file, vpnNetworks, useWG) if err != nil { return nil, err } @@ -113,7 +116,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", false)) + t, err := newTunGeneric(c, l, file, vpnNetworks, useWG) if err != nil { return nil, err } @@ -123,16 +127,45 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) { + var ( + rw io.ReadWriteCloser = file + fd = int(file.Fd()) + wgDev wgtun.Device + ) + + if useWireguard { + dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU)) + if err != nil { + return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err) + } + wgDev = dev + rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU)) + fd = int(dev.File().Fd()) + } + t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), + ReadWriteCloser: rw, + fd: fd, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), l: l, } + if wgDev != nil { + t.wgDevice = wgDev + } + if wgDev != nil { + // replace ioctl fd with device file descriptor to keep route management working + file = wgDev.File() + t.fd = int(file.Fd()) + t.ioctlFd = file.Fd() + } + + if t.ioctlFd == 0 { + t.ioctlFd = file.Fd() + } err := t.reload(c, true) if err != nil { @@ -678,6 +711,14 @@ func (t *tun) Close() error { _ = t.ReadWriteCloser.Close() } + if t.wgDevice != nil { + _ = t.wgDevice.Close() + if t.ioctlFd > 0 { + // underlying fd already closed by the device + t.ioctlFd = 0 + } + } + if t.ioctlFd > 0 { _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() }