better and batched tun interface

This commit is contained in:
JackDoan
2026-04-17 10:25:05 -05:00
parent 398d67e2da
commit 4b4331ba42
34 changed files with 1189 additions and 483 deletions

View File

@@ -0,0 +1,69 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
type pollContainer struct {
pq []*Poll
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
}
func NewPollContainer() (Container, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &pollContainer{
pq: []*Poll{},
pqi: []Queue{},
shutdownFd: shutdownFd,
}
return out, nil
}
func (c *pollContainer) Queues() []Queue {
return c.pqi
}
func (c *pollContainer) Add(fd int) error {
x, err := newPoll(fd, c.shutdownFd)
if err != nil {
return err
}
c.pq = append(c.pq, x)
c.pqi = append(c.pqi, x)
return nil
}
func (c *pollContainer) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:])
return err
}
func (c *pollContainer) Close() error {
errs := []error{}
if err := c.wakeForShutdown(); err != nil {
errs = append(errs, err)
}
for _, x := range c.pq {
if err := x.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}

56
overlay/tio/tio.go Normal file
View File

@@ -0,0 +1,56 @@
package tio
import "io"
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
// Container holds one or many Queue objects and helps close them in an orderly way
type Container interface {
io.Closer
Queues() []Queue
// Add takes a tun fd, adds it to the container, and prepares it for use as a Queue
Add(fd int) error
}
// Queue is a readable/writable Poll queue. One Queue is driven by a single
// read goroutine plus concurrent writers (see Write / WriteReject below).
type Queue interface {
io.Closer
// Read returns one or more packets. The returned slices are borrowed
// from the Queue's internal buffer and are only valid until the next
// Read or Close on this Queue - callers must encrypt or copy each
// slice before the next call. Not safe for concurrent Reads.
Read() ([][]byte, error)
// Write emits a single packet on the plaintext (outside→inside)
// delivery path. Not safe for concurrent Writes.
Write(p []byte) (int, error)
}
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
// assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO
// support return false from GSOSupported and coalescing is skipped.
//
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will
// have filled in total length and pseudo-header partial). pays are
// non-overlapping payload fragments whose concatenation is the full
// superpacket payload; they are read-only from the writer's perspective
// and must remain valid until the call returns. gsoSize is the MSS:
// every segment except possibly the last is exactly that many bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
//
// # TODO fold into Queue
//
// hdr's TCP checksum field must already hold the pseudo-header partial
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
type GSOWriter interface {
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
GSOSupported() bool
}

View File

@@ -0,0 +1,164 @@
package tio
import (
"fmt"
"os"
"sync/atomic"
"golang.org/x/sys/unix"
)
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
// prefix plus the virtio header.
const tunReadBufSize = 65535
type Poll struct {
fd int
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed atomic.Bool
readBuf []byte
batchRet [1][]byte
}
func newPoll(fd int, shutdownFd int) (*Poll, error) {
if err := unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err)
}
out := &Poll{
fd: fd,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
}
return out, nil
}
// blockOnRead waits until the Poll fd is readable or shutdown has been signaled.
// Returns os.ErrClosed if Close was called.
func (t *Poll) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(t.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
tunEvents := t.readPoll[0].Revents
shutdownEvents := t.readPoll[1].Revents
t.readPoll[0].Revents = 0
t.readPoll[1].Revents = 0
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
}
if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (t *Poll) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(t.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
tunEvents := t.writePoll[0].Revents
shutdownEvents := t.writePoll[1].Revents
t.writePoll[0].Revents = 0
t.writePoll[1].Revents = 0
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
}
if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (t *Poll) Read() ([][]byte, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
}
func (t *Poll) readOne(to []byte) (int, error) {
for {
n, errno := unix.Read(t.fd, to)
if errno == nil {
return n, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnRead(); err != nil {
return 0, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
}
// Write is only valid for single threaded use
func (t *Poll) Write(from []byte) (int, error) {
for {
n, errno := unix.Write(t.fd, from)
if errno == nil {
return n, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnWrite(); err != nil {
return 0, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
}
func (t *Poll) Close() error {
if t.closed.Swap(true) {
return nil
}
//shutdownFd is owned by the container, so we should not close it
var err error
if t.fd >= 0 {
err = unix.Close(t.fd)
t.fd = -1
}
return err
}

View File

@@ -0,0 +1,82 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"errors"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newOffload / newFriend).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
t.Fatalf("pipe2: %v", err)
}
t.Cleanup(func() { _ = unix.Close(fds[1]) })
return fds[0]
}
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
pipe1 := newReadPipe(t)
pipe2 := newReadPipe(t)
parent, err := NewPollContainer()
require.NoError(t, err)
require.NoError(t, parent.Add(pipe1))
require.NoError(t, parent.Add(pipe2))
t.Cleanup(func() {
_ = unix.Close(pipe1)
_ = unix.Close(pipe2)
})
readers := parent.Queues()
errs := make([]error, len(readers))
var wg sync.WaitGroup
for i, r := range readers {
wg.Add(1)
go func(i int, r Queue) {
defer wg.Done()
_, errs[i] = r.Read()
}(i, r)
}
time.Sleep(50 * time.Millisecond)
if err := parent.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("readers did not wake")
}
for i, err := range errs {
if !errors.Is(err, os.ErrClosed) {
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
}
}
}
func TestPoll_Close_Idempotent(t *testing.T) {
tf, err := newPoll(newReadPipe(t), 1)
require.NoError(t, err)
if err := tf.Close(); err != nil {
t.Fatalf("first Close: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("second Close should be a no-op, got %v", err)
}
}