diff --git a/interface.go b/interface.go index 61f8c9b7..d642cbbb 100644 --- a/interface.go +++ b/interface.go @@ -485,15 +485,7 @@ func (f *Interface) Close() error { f.l.WithError(err).Error("Error while closing udp socket") } } - for i, r := range f.readers { - if i == 0 { - continue // f.readers[0] is f.inside, which we want to save for last - } - if err := r.Close(); err != nil { - f.l.WithError(err).Error("Error while closing tun reader") - } - } - // Release the tun device + // Release the tun device (closing the tun also closes all readers) return f.inside.Close() } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index c6806441..8bd1609b 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -25,61 +25,118 @@ import ( "golang.org/x/sys/unix" ) -// tunFd wraps a non-blocking TUN file descriptor with poll-based reads. +// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. // A shared eventfd allows Close to wake all readers blocked in poll. -type tunFd struct { +type tunFile struct { fd int + lastOne bool pollFds [2]unix.PollFd + closed bool } -func newTunFd(fd, shutdownFd int) *tunFd { - return &tunFd{ +// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun +func (r *tunFile) newFriend(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + return &tunFile{ fd: fd, + pollFds: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: r.pollFds[1].Fd, Events: unix.POLLIN}, + }, + }, nil +} + +func newTunFd(fd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &tunFile{ + fd: fd, + lastOne: true, pollFds: [2]unix.PollFd{ {Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(shutdownFd), Events: unix.POLLIN}, }, } + + return out, nil } -func (r *tunFd) Read(buf []byte) (int, error) { +func (r *tunFile) block() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error for { - n, err := unix.Read(r.fd, buf) - if err != nil { - if err == unix.EAGAIN { - for { - _, err = unix.Poll(r.pollFds[:], -1) - if err != unix.EINTR { - break - } - } - if err != nil { - return 0, err - } - if r.pollFds[1].Revents&unix.POLLIN != 0 { - return 0, os.ErrClosed - } - r.pollFds[0].Revents = 0 - r.pollFds[1].Revents = 0 - continue + _, err = unix.Poll(r.pollFds[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.pollFds[0].Revents + shutdownEvents := r.pollFds[1].Revents + r.pollFds[0].Revents = 0 + r.pollFds[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) Read(buf []byte) (int, error) { + for { + if n, err := unix.Read(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.block(); err != nil { + return 0, err } + continue + } else { return 0, err } - return n, nil } } -func (r *tunFd) Write(buf []byte) (int, error) { +func (r *tunFile) Write(buf []byte) (int, error) { return unix.Write(r.fd, buf) } -func (r *tunFd) Close() error { +func (r *tunFile) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(r.pollFds[1].Fd), buf[:]) + return err +} + +func (r *tunFile) Close() error { + if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem + return nil + } + r.closed = true + if r.lastOne { + _ = unix.Close(int(r.pollFds[1].Fd)) + } return unix.Close(r.fd) } type tun struct { - *tunFd - shutdownFd int + *tunFile + readers []*tunFile + closeLock sync.Mutex Device string vpnNetworks []netip.Prefix MaxMTU int @@ -166,6 +223,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu nameStr := c.GetString("tun.dev", "") copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + _ = unix.Close(fd) return nil, &NameError{ Name: nameStr, Underlying: err, @@ -183,22 +241,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } +// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { - shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + tfd, err := newTunFd(fd) if err != nil { _ = unix.Close(fd) - return nil, fmt.Errorf("failed to create eventfd: %w", err) + return nil, err } - - if err = unix.SetNonblock(fd, true); err != nil { - _ = unix.Close(fd) - _ = unix.Close(shutdownFd) - return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) - } - t := &tun{ - tunFd: newTunFd(fd, shutdownFd), - shutdownFd: shutdownFd, + tunFile: tfd, + readers: []*tunFile{tfd}, + closeLock: sync.Mutex{}, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), @@ -207,7 +260,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Pr l: l, } - if err := t.reload(c, true); err != nil { + if err = t.reload(c, true); err != nil { + _ = t.Close() return nil, err } @@ -300,6 +354,9 @@ func (t *tun) SupportsMultiqueue() bool { } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + t.closeLock.Lock() + defer t.closeLock.Unlock() + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -313,12 +370,15 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } - if err = unix.SetNonblock(fd, true); err != nil { + out, err := t.tunFile.newFriend(fd) + if err != nil { _ = unix.Close(fd) return nil, err } - return newTunFd(fd, t.shutdownFd), nil + t.readers = append(t.readers, out) + + return out, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -749,30 +809,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } func (t *tun) Close() error { + t.closeLock.Lock() + defer t.closeLock.Unlock() + if t.routeChan != nil { close(t.routeChan) + t.routeChan = nil } // Signal all readers blocked in poll to wake up and exit - if t.shutdownFd >= 0 { - var buf [8]byte - binary.NativeEndian.PutUint64(buf[:], 1) - _, _ = unix.Write(t.shutdownFd, buf[:]) - } - - if t.tunFd != nil { - _ = t.tunFd.Close() - } - - if t.shutdownFd >= 0 { - _ = unix.Close(t.shutdownFd) - t.shutdownFd = -1 - } + _ = t.tunFile.wakeForShutdown() if t.ioctlFd > 0 { - _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = unix.Close(int(t.ioctlFd)) t.ioctlFd = 0 } - return nil + for i := range t.readers { + if i == 0 { + continue //we want to close the zeroth reader last + } + err := t.readers[i].Close() + if err != nil { + t.l.WithField("reader", i).WithError(err).Error("Error closing tun reader") + } else { + t.l.WithField("reader", i).Info("Closed tun reader") + } + } + + //this is t.readers[0] too + err := t.tunFile.Close() + if err != nil { + t.l.WithField("reader", 0).WithError(err).Error("Error closing tun reader") + } else { + t.l.WithField("reader", 0).Info("Closed tun reader") + } + return err }