fancy blocking writes

This commit is contained in:
JackDoan
2026-04-16 13:43:13 -05:00
parent 6b2e6d9f55
commit ba8da0e86c

View File

@@ -29,8 +29,10 @@ import (
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct {
fd int
shutdownFd int
lastOne bool
pollFds [2]unix.PollFd
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed bool
}
@@ -41,9 +43,14 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
}
return &tunFile{
fd: fd,
pollFds: [2]unix.PollFd{
shutdownFd: r.shutdownFd,
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: r.pollFds[1].Fd, Events: unix.POLLIN},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
}, nil
}
@@ -60,30 +67,61 @@ func newTunFd(fd int) (*tunFile, error) {
out := &tunFile{
fd: fd,
shutdownFd: shutdownFd,
lastOne: true,
pollFds: [2]unix.PollFd{
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
}
return out, nil
}
func (r *tunFile) block() error {
func (r *tunFile) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.pollFds[:], -1)
_, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.pollFds[0].Revents
shutdownEvents := r.pollFds[1].Revents
r.pollFds[0].Revents = 0
r.pollFds[1].Revents = 0
tunEvents := r.readPoll[0].Revents
shutdownEvents := r.readPoll[1].Revents
r.readPoll[0].Revents = 0
r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
@@ -101,7 +139,7 @@ func (r *tunFile) Read(buf []byte) (int, error) {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.block(); err != nil {
if err = r.blockOnRead(); err != nil {
return 0, err
}
continue
@@ -112,13 +150,26 @@ func (r *tunFile) Read(buf []byte) (int, error) {
}
func (r *tunFile) Write(buf []byte) (int, error) {
return unix.Write(r.fd, buf)
for {
if n, err := unix.Write(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnWrite(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else {
return 0, err
}
}
}
func (r *tunFile) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(r.pollFds[1].Fd), buf[:])
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
return err
}
@@ -128,7 +179,7 @@ func (r *tunFile) Close() error {
}
r.closed = true
if r.lastOne {
_ = unix.Close(int(r.pollFds[1].Fd))
_ = unix.Close(r.shutdownFd)
}
return unix.Close(r.fd)
}