From 9ac45a06cfd8140240ce6ce3c58dbc10b10ead92 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 15 Apr 2026 17:45:50 -0500 Subject: [PATCH] tun_linux.go: stdlib too slow, but can't use blocking IO and clean shutdown --- overlay/tun_linux.go | 109 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 93 insertions(+), 16 deletions(-) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 9d779a4b..beec79e1 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,6 +4,7 @@ package overlay import ( + "encoding/binary" "fmt" "io" "net" @@ -24,9 +25,61 @@ import ( "golang.org/x/sys/unix" ) +// tunFd wraps a non-blocking TUN file descriptor with poll-based reads. +// A shared eventfd allows Close to wake all readers blocked in poll. +type tunFd struct { + fd int + pollFds [2]unix.PollFd +} + +func newTunFd(fd, shutdownFd int) *tunFd { + return &tunFd{ + fd: fd, + pollFds: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + } +} + +func (r *tunFd) Read(buf []byte) (int, error) { + for { + n, err := unix.Read(r.fd, buf) + if err != nil { + if err == unix.EAGAIN { + for { + _, err = unix.Poll(r.pollFds[:], -1) + if err != unix.EINTR { + break + } + } + if err != nil { + return 0, err + } + if r.pollFds[1].Revents&unix.POLLIN != 0 { + return 0, os.ErrClosed + } + r.pollFds[0].Revents = 0 + r.pollFds[1].Revents = 0 + continue + } + return 0, err + } + return n, nil + } +} + +func (r *tunFd) Write(buf []byte) (int, error) { + return unix.Write(r.fd, buf) +} + +func (r *tunFd) Close() error { + return unix.Close(r.fd) +} + type tun struct { - io.ReadWriteCloser - fd int + *tunFd + shutdownFd int Device string vpnNetworks []netip.Prefix MaxMTU int @@ -72,9 +125,7 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) if err != nil { return nil, err } @@ -122,8 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := strings.Trim(string(req.Name[:]), "\x00") - file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, fd, vpnNetworks) if err != nil { return nil, err } @@ -133,10 +183,22 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + unix.Close(fd) + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + if err := unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + unix.Close(shutdownFd) + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), + tunFd: newTunFd(fd, shutdownFd), + shutdownFd: shutdownFd, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), @@ -145,8 +207,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n l: l, } - err := t.reload(c, true) - if err != nil { + if err := t.reload(c, true); err != nil { return nil, err } @@ -248,12 +309,16 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } - return file, nil + return newTunFd(fd, t.shutdownFd), nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -688,8 +753,20 @@ func (t *tun) Close() error { close(t.routeChan) } - if t.ReadWriteCloser != nil { - _ = t.ReadWriteCloser.Close() + // Signal all readers blocked in poll to wake up and exit + if t.shutdownFd >= 0 { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + unix.Write(t.shutdownFd, buf[:]) + } + + if t.tunFd != nil { + _ = t.tunFd.Close() + } + + if t.shutdownFd >= 0 { + unix.Close(t.shutdownFd) + t.shutdownFd = -1 } if t.ioctlFd > 0 {