This commit is contained in:
JackDoan
2026-04-16 12:26:35 -05:00
parent b644131fd7
commit 2a0fd0be1d
2 changed files with 129 additions and 67 deletions

View File

@@ -485,15 +485,7 @@ func (f *Interface) Close() error {
f.l.WithError(err).Error("Error while closing udp socket") f.l.WithError(err).Error("Error while closing udp socket")
} }
} }
for i, r := range f.readers {
if i == 0 {
continue // f.readers[0] is f.inside, which we want to save for last
}
if err := r.Close(); err != nil {
f.l.WithError(err).Error("Error while closing tun reader")
}
}
// Release the tun device // Release the tun device (closing the tun also closes all readers)
return f.inside.Close() return f.inside.Close()
} }

View File

@@ -25,61 +25,118 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// tunFd wraps a non-blocking TUN file descriptor with poll-based reads. // tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll. // A shared eventfd allows Close to wake all readers blocked in poll.
type tunFd struct { type tunFile struct {
fd int fd int
lastOne bool
pollFds [2]unix.PollFd pollFds [2]unix.PollFd
closed bool
} }
func newTunFd(fd, shutdownFd int) *tunFd { // newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
return &tunFd{ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
return &tunFile{
fd: fd, fd: fd,
pollFds: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: r.pollFds[1].Fd, Events: unix.POLLIN},
},
}, nil
}
func newTunFd(fd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &tunFile{
fd: fd,
lastOne: true,
pollFds: [2]unix.PollFd{ pollFds: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN}, {Fd: int32(shutdownFd), Events: unix.POLLIN},
}, },
} }
return out, nil
} }
func (r *tunFd) Read(buf []byte) (int, error) { func (r *tunFile) block() error {
for { const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
n, err := unix.Read(r.fd, buf) var err error
if err != nil {
if err == unix.EAGAIN {
for { for {
_, err = unix.Poll(r.pollFds[:], -1) _, err = unix.Poll(r.pollFds[:], -1)
if err != unix.EINTR { if err != unix.EINTR {
break break
} }
} }
if err != nil { //always reset these!
return 0, err tunEvents := r.pollFds[0].Revents
} shutdownEvents := r.pollFds[1].Revents
if r.pollFds[1].Revents&unix.POLLIN != 0 {
return 0, os.ErrClosed
}
r.pollFds[0].Revents = 0 r.pollFds[0].Revents = 0
r.pollFds[1].Revents = 0 r.pollFds[1].Revents = 0
continue //do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
} }
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) Read(buf []byte) (int, error) {
for {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.block(); err != nil {
return 0, err
}
continue
} else {
return 0, err return 0, err
} }
return n, nil
} }
} }
func (r *tunFd) Write(buf []byte) (int, error) { func (r *tunFile) Write(buf []byte) (int, error) {
return unix.Write(r.fd, buf) return unix.Write(r.fd, buf)
} }
func (r *tunFd) Close() error { func (r *tunFile) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(r.pollFds[1].Fd), buf[:])
return err
}
func (r *tunFile) Close() error {
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
return nil
}
r.closed = true
if r.lastOne {
_ = unix.Close(int(r.pollFds[1].Fd))
}
return unix.Close(r.fd) return unix.Close(r.fd)
} }
type tun struct { type tun struct {
*tunFd *tunFile
shutdownFd int readers []*tunFile
closeLock sync.Mutex
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MaxMTU int MaxMTU int
@@ -166,6 +223,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
nameStr := c.GetString("tun.dev", "") nameStr := c.GetString("tun.dev", "")
copy(req.Name[:], nameStr) copy(req.Name[:], nameStr)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
_ = unix.Close(fd)
return nil, &NameError{ return nil, &NameError{
Name: nameStr, Name: nameStr,
Underlying: err, Underlying: err,
@@ -183,22 +241,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil return t, nil
} }
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) tfd, err := newTunFd(fd)
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, fmt.Errorf("failed to create eventfd: %w", err) return nil, err
} }
if err = unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
_ = unix.Close(shutdownFd)
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
t := &tun{ t := &tun{
tunFd: newTunFd(fd, shutdownFd), tunFile: tfd,
shutdownFd: shutdownFd, readers: []*tunFile{tfd},
closeLock: sync.Mutex{},
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
@@ -207,7 +260,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Pr
l: l, l: l,
} }
if err := t.reload(c, true); err != nil { if err = t.reload(c, true); err != nil {
_ = t.Close()
return nil, err return nil, err
} }
@@ -300,6 +354,9 @@ func (t *tun) SupportsMultiqueue() bool {
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
t.closeLock.Lock()
defer t.closeLock.Unlock()
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -313,12 +370,15 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, err return nil, err
} }
if err = unix.SetNonblock(fd, true); err != nil { out, err := t.tunFile.newFriend(fd)
if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, err return nil, err
} }
return newTunFd(fd, t.shutdownFd), nil t.readers = append(t.readers, out)
return out, nil
} }
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -749,30 +809,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
} }
func (t *tun) Close() error { func (t *tun) Close() error {
t.closeLock.Lock()
defer t.closeLock.Unlock()
if t.routeChan != nil { if t.routeChan != nil {
close(t.routeChan) close(t.routeChan)
t.routeChan = nil
} }
// Signal all readers blocked in poll to wake up and exit // Signal all readers blocked in poll to wake up and exit
if t.shutdownFd >= 0 { _ = t.tunFile.wakeForShutdown()
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, _ = unix.Write(t.shutdownFd, buf[:])
}
if t.tunFd != nil {
_ = t.tunFd.Close()
}
if t.shutdownFd >= 0 {
_ = unix.Close(t.shutdownFd)
t.shutdownFd = -1
}
if t.ioctlFd > 0 { if t.ioctlFd > 0 {
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close() _ = unix.Close(int(t.ioctlFd))
t.ioctlFd = 0 t.ioctlFd = 0
} }
return nil for i := range t.readers {
if i == 0 {
continue //we want to close the zeroth reader last
}
err := t.readers[i].Close()
if err != nil {
t.l.WithField("reader", i).WithError(err).Error("Error closing tun reader")
} else {
t.l.WithField("reader", i).Info("Closed tun reader")
}
}
//this is t.readers[0] too
err := t.tunFile.Close()
if err != nil {
t.l.WithField("reader", 0).WithError(err).Error("Error closing tun reader")
} else {
t.l.WithField("reader", 0).Info("Closed tun reader")
}
return err
} }