fancy blocking writes

This commit is contained in:
JackDoan
2026-04-16 13:43:13 -05:00
parent 6b2e6d9f55
commit ba8da0e86c

View File

@@ -29,8 +29,10 @@ import (
// A shared eventfd allows Close to wake all readers blocked in poll. // A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct { type tunFile struct {
fd int fd int
shutdownFd int
lastOne bool lastOne bool
pollFds [2]unix.PollFd readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed bool closed bool
} }
@@ -41,9 +43,14 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
} }
return &tunFile{ return &tunFile{
fd: fd, fd: fd,
pollFds: [2]unix.PollFd{ shutdownFd: r.shutdownFd,
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(fd), Events: unix.POLLIN},
{Fd: r.pollFds[1].Fd, Events: unix.POLLIN}, {Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
}, },
}, nil }, nil
} }
@@ -60,30 +67,61 @@ func newTunFd(fd int) (*tunFile, error) {
out := &tunFile{ out := &tunFile{
fd: fd, fd: fd,
shutdownFd: shutdownFd,
lastOne: true, lastOne: true,
pollFds: [2]unix.PollFd{ readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN}, {Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN}, {Fd: int32(shutdownFd), Events: unix.POLLIN},
}, },
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
} }
return out, nil return out, nil
} }
func (r *tunFile) block() error { func (r *tunFile) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error var err error
for { for {
_, err = unix.Poll(r.pollFds[:], -1) _, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR { if err != unix.EINTR {
break break
} }
} }
//always reset these! //always reset these!
tunEvents := r.pollFds[0].Revents tunEvents := r.readPoll[0].Revents
shutdownEvents := r.pollFds[1].Revents shutdownEvents := r.readPoll[1].Revents
r.pollFds[0].Revents = 0 r.readPoll[0].Revents = 0
r.pollFds[1].Revents = 0 r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got //do the err check before trusting the potentially bogus bits we just got
if err != nil { if err != nil {
return err return err
@@ -101,7 +139,7 @@ func (r *tunFile) Read(buf []byte) (int, error) {
if n, err := unix.Read(r.fd, buf); err == nil { if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil return n, nil
} else if err == unix.EAGAIN { } else if err == unix.EAGAIN {
if err = r.block(); err != nil { if err = r.blockOnRead(); err != nil {
return 0, err return 0, err
} }
continue continue
@@ -112,13 +150,26 @@ func (r *tunFile) Read(buf []byte) (int, error) {
} }
func (r *tunFile) Write(buf []byte) (int, error) { func (r *tunFile) Write(buf []byte) (int, error) {
return unix.Write(r.fd, buf) for {
if n, err := unix.Write(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnWrite(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else {
return 0, err
}
}
} }
func (r *tunFile) wakeForShutdown() error { func (r *tunFile) wakeForShutdown() error {
var buf [8]byte var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1) binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(r.pollFds[1].Fd), buf[:]) _, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
return err return err
} }
@@ -128,7 +179,7 @@ func (r *tunFile) Close() error {
} }
r.closed = true r.closed = true
if r.lastOne { if r.lastOne {
_ = unix.Close(int(r.pollFds[1].Fd)) _ = unix.Close(r.shutdownFd)
} }
return unix.Close(r.fd) return unix.Close(r.fd)
} }