diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index d5c6c745..6d7d9fb8 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -28,10 +28,12 @@ import ( // tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. // A shared eventfd allows Close to wake all readers blocked in poll. type tunFile struct { - fd int - lastOne bool - pollFds [2]unix.PollFd - closed bool + fd int + shutdownFd int + lastOne bool + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed bool } // newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun @@ -40,10 +42,15 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) { return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) } return &tunFile{ - fd: fd, - pollFds: [2]unix.PollFd{ + fd: fd, + shutdownFd: r.shutdownFd, + readPoll: [2]unix.PollFd{ {Fd: int32(fd), Events: unix.POLLIN}, - {Fd: r.pollFds[1].Fd, Events: unix.POLLIN}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, }, }, nil } @@ -59,31 +66,62 @@ func newTunFd(fd int) (*tunFile, error) { } out := &tunFile{ - fd: fd, - lastOne: true, - pollFds: [2]unix.PollFd{ + fd: fd, + shutdownFd: shutdownFd, + lastOne: true, + readPoll: [2]unix.PollFd{ {Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(shutdownFd), Events: unix.POLLIN}, }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, } return out, nil } -func (r *tunFile) block() error { +func (r *tunFile) blockOnRead() error { const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR var err error for { - _, err = unix.Poll(r.pollFds[:], -1) + _, err = unix.Poll(r.readPoll[:], -1) if err != unix.EINTR { break } } //always reset these! - tunEvents := r.pollFds[0].Revents - shutdownEvents := r.pollFds[1].Revents - r.pollFds[0].Revents = 0 - r.pollFds[1].Revents = 0 + tunEvents := r.readPoll[0].Revents + shutdownEvents := r.readPoll[1].Revents + r.readPoll[0].Revents = 0 + r.readPoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.writePoll[0].Revents + shutdownEvents := r.writePoll[1].Revents + r.writePoll[0].Revents = 0 + r.writePoll[1].Revents = 0 //do the err check before trusting the potentially bogus bits we just got if err != nil { return err @@ -101,7 +139,7 @@ func (r *tunFile) Read(buf []byte) (int, error) { if n, err := unix.Read(r.fd, buf); err == nil { return n, nil } else if err == unix.EAGAIN { - if err = r.block(); err != nil { + if err = r.blockOnRead(); err != nil { return 0, err } continue @@ -112,13 +150,26 @@ func (r *tunFile) Read(buf []byte) (int, error) { } func (r *tunFile) Write(buf []byte) (int, error) { - return unix.Write(r.fd, buf) + for { + if n, err := unix.Write(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnWrite(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else { + return 0, err + } + } } func (r *tunFile) wakeForShutdown() error { var buf [8]byte binary.NativeEndian.PutUint64(buf[:], 1) - _, err := unix.Write(int(r.pollFds[1].Fd), buf[:]) + _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) return err } @@ -128,7 +179,7 @@ func (r *tunFile) Close() error { } r.closed = true if r.lastOne { - _ = unix.Close(int(r.pollFds[1].Fd)) + _ = unix.Close(r.shutdownFd) } return unix.Close(r.fd) }