batch tun reads

This commit is contained in:
Jay Wren
2026-02-03 17:12:44 -05:00
parent 15333f9fed
commit 30db76ed79
5 changed files with 447 additions and 7 deletions

View File

@@ -34,6 +34,7 @@ type tun struct {
TXQueueLen int
deviceIndex int
ioctlFd uintptr
nonBlocking bool // true if fd is in non-blocking mode
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -239,7 +240,7 @@ func (t *tun) SupportsMultiqueue() bool {
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
fd, err := unix.Open("/dev/net/tun", os.O_RDWR|unix.O_NONBLOCK, 0)
if err != nil {
return nil, err
}
@@ -251,9 +252,99 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return &tunBatchReader{fd: fd, device: t.Device}, nil
}
return file, nil
// tunBatchReader implements BatchReader for efficient batch packet reading
type tunBatchReader struct {
fd int
device string
}
func (r *tunBatchReader) Read(b []byte) (int, error) {
// Use poll to wait for data, then read
for {
n, err := unix.Read(r.fd, b)
if err == nil {
return n, nil
}
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
// Wait for data
pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
_, err = unix.Poll(pfds, -1)
if err != nil {
if err == unix.EINTR {
continue
}
return 0, err
}
continue
}
return n, err
}
}
func (r *tunBatchReader) Write(b []byte) (int, error) {
return unix.Write(r.fd, b)
}
func (r *tunBatchReader) Close() error {
return unix.Close(r.fd)
}
// ReadBatch reads up to len(packets) packets from the TUN device.
// It drains all available packets without blocking, using poll() only
// when no packets have been read yet.
func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) {
count := 0
maxPackets := len(packets)
if len(sizes) < maxPackets {
maxPackets = len(sizes)
}
for count < maxPackets {
n, err := unix.Read(r.fd, packets[count])
if err == nil && n > 0 {
sizes[count] = n
count++
continue
}
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
// No more packets available
if count > 0 {
// We have some packets, return them
return count, nil
}
// No packets yet, wait for at least one
pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
_, err = unix.Poll(pfds, -1)
if err != nil {
if err == unix.EINTR {
continue
}
return 0, err
}
continue
}
if err != nil {
if count > 0 {
// Return what we have
return count, nil
}
return 0, err
}
if n == 0 {
if count > 0 {
return count, nil
}
return 0, io.EOF
}
}
return count, nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -284,6 +375,111 @@ func (t *tun) Write(b []byte) (int, error) {
}
}
// EnableBatchReading sets the TUN fd to non-blocking mode to enable batch reading.
// This should be called before using ReadBatch.
func (t *tun) EnableBatchReading() error {
if t.nonBlocking {
return nil
}
err := unix.SetNonblock(t.fd, true)
if err != nil {
return err
}
t.nonBlocking = true
return nil
}
// Read overrides the default Read to handle non-blocking mode
func (t *tun) Read(b []byte) (int, error) {
if !t.nonBlocking {
// Use the embedded ReadWriteCloser for blocking reads
return t.ReadWriteCloser.Read(b)
}
// Non-blocking read with poll
for {
n, err := unix.Read(t.fd, b)
if err == nil {
return n, nil
}
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
// Wait for data
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
_, err = unix.Poll(pfds, -1)
if err != nil {
if err == unix.EINTR {
continue
}
return 0, err
}
continue
}
return n, err
}
}
// ReadBatch reads up to len(packets) packets from the TUN device.
// EnableBatchReading must be called first.
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
if !t.nonBlocking {
// Fallback to single read if non-blocking not enabled
n, err := t.ReadWriteCloser.Read(packets[0])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
}
count := 0
maxPackets := len(packets)
if len(sizes) < maxPackets {
maxPackets = len(sizes)
}
for count < maxPackets {
n, err := unix.Read(t.fd, packets[count])
if err == nil && n > 0 {
sizes[count] = n
count++
continue
}
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
// No more packets available
if count > 0 {
return count, nil
}
// No packets yet, wait for at least one
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
_, err = unix.Poll(pfds, -1)
if err != nil {
if err == unix.EINTR {
continue
}
return 0, err
}
continue
}
if err != nil {
if count > 0 {
return count, nil
}
return 0, err
}
if n == 0 {
if count > 0 {
return count, nil
}
return 0, io.EOF
}
}
return count, nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)