tun_linux.go: stdlib too slow, but can't use blocking IO and clean shutdown

This commit is contained in:
JackDoan
2026-04-15 17:45:50 -05:00
parent 72c04b90bd
commit 9ac45a06cf

View File

@@ -4,6 +4,7 @@
package overlay package overlay
import ( import (
"encoding/binary"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -24,9 +25,61 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// tunFd wraps a non-blocking TUN file descriptor with poll-based reads.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFd struct {
fd int
pollFds [2]unix.PollFd
}
func newTunFd(fd, shutdownFd int) *tunFd {
return &tunFd{
fd: fd,
pollFds: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
}
}
func (r *tunFd) Read(buf []byte) (int, error) {
for {
n, err := unix.Read(r.fd, buf)
if err != nil {
if err == unix.EAGAIN {
for {
_, err = unix.Poll(r.pollFds[:], -1)
if err != unix.EINTR {
break
}
}
if err != nil {
return 0, err
}
if r.pollFds[1].Revents&unix.POLLIN != 0 {
return 0, os.ErrClosed
}
r.pollFds[0].Revents = 0
r.pollFds[1].Revents = 0
continue
}
return 0, err
}
return n, nil
}
}
func (r *tunFd) Write(buf []byte) (int, error) {
return unix.Write(r.fd, buf)
}
func (r *tunFd) Close() error {
return unix.Close(r.fd)
}
type tun struct { type tun struct {
io.ReadWriteCloser *tunFd
fd int shutdownFd int
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MaxMTU int MaxMTU int
@@ -72,9 +125,7 @@ type ifreqQLEN struct {
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -122,8 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
} }
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun") t, err := newTunGeneric(c, l, fd, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -133,10 +183,22 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil return t, nil
} }
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
if err := unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
unix.Close(shutdownFd)
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
t := &tun{ t := &tun{
ReadWriteCloser: file, tunFd: newTunFd(fd, shutdownFd),
fd: int(file.Fd()), shutdownFd: shutdownFd,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
@@ -145,8 +207,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
l: l, l: l,
} }
err := t.reload(c, true) if err := t.reload(c, true); err != nil {
if err != nil {
return nil, err return nil, err
} }
@@ -248,12 +309,16 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device) copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
unix.Close(fd)
return nil, err return nil, err
} }
file := os.NewFile(uintptr(fd), "/dev/net/tun") if err = unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
return nil, err
}
return file, nil return newTunFd(fd, t.shutdownFd), nil
} }
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -688,8 +753,20 @@ func (t *tun) Close() error {
close(t.routeChan) close(t.routeChan)
} }
if t.ReadWriteCloser != nil { // Signal all readers blocked in poll to wake up and exit
_ = t.ReadWriteCloser.Close() if t.shutdownFd >= 0 {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
unix.Write(t.shutdownFd, buf[:])
}
if t.tunFd != nil {
_ = t.tunFd.Close()
}
if t.shutdownFd >= 0 {
unix.Close(t.shutdownFd)
t.shutdownFd = -1
} }
if t.ioctlFd > 0 { if t.ioctlFd > 0 {