diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 366a559..9df4581 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -99,10 +99,36 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu devName := c.GetString("tun.dev", "") mtu := c.GetInt("tun.mtu", DefaultMTU) - // Create TUN device using wireguard library - wgDev, err := wgtun.CreateTUN(devName, mtu) + // Create TUN device manually to support multiqueue + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { - return nil, fmt.Errorf("failed to create TUN device: %w", err) + return nil, err + } + + var req ifReq + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + if multiqueue { + req.Flags |= unix.IFF_MULTI_QUEUE + } + copy(req.Name[:], devName) + if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) + return nil, err + } + + // Set nonblocking + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + file := os.NewFile(uintptr(fd), "/dev/net/tun") + + // Create wireguard device from file descriptor + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create TUN from file: %w", err) } name, err := wgDev.Name() @@ -111,7 +137,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return nil, fmt.Errorf("failed to get TUN device name: %w", err) } - file := wgDev.File() + // file is now owned by wgDev, get a new reference + file = wgDev.File() t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { _ = wgDev.Close() @@ -224,6 +251,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } var req ifReq + // MUST match the flags used in newTun req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { @@ -241,20 +269,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { } func (t *tun) Write(b []byte) (int, error) { - if t.wgDevice != nil { - // Use wireguard device for writing - bufs := [][]byte{b} - n, err := t.wgDevice.Write(bufs, 0) - if err != nil { - return 0, err - } - if n != 1 { - return 0, fmt.Errorf("expected to write 1 packet, wrote %d", n) - } - return len(b), nil - } - - // Fallback to direct fd write if no wireguard device var nn int maximum := len(b)