fix tests

This commit is contained in:
JackDoan
2026-04-23 11:35:51 -05:00
parent 382b15ac52
commit f76ac2e216
8 changed files with 97 additions and 164 deletions

View File

@@ -8,21 +8,21 @@ import (
"golang.org/x/sys/unix"
)
type gsoContainer struct {
pq []*tunFile
type offloadContainer struct {
pq []*Offload
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
}
func NewGSOContainer() (Container, error) {
func NewOffloadContainer() (Container, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &gsoContainer{
pq: []*tunFile{},
out := &offloadContainer{
pq: []*Offload{},
pqi: []Queue{},
shutdownFd: shutdownFd,
}
@@ -30,12 +30,12 @@ func NewGSOContainer() (Container, error) {
return out, nil
}
func (c *gsoContainer) Queues() []Queue {
func (c *offloadContainer) Queues() []Queue {
return c.pqi
}
func (c *gsoContainer) Add(fd int) error {
x, err := newTunFd(fd, c.shutdownFd)
func (c *offloadContainer) Add(fd int) error {
x, err := newOffload(fd, c.shutdownFd)
if err != nil {
return err
}
@@ -45,14 +45,14 @@ func (c *gsoContainer) Add(fd int) error {
return nil
}
func (c *gsoContainer) wakeForShutdown() error {
func (c *offloadContainer) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:])
_, err := unix.Write(c.shutdownFd, buf[:])
return err
}
func (c *gsoContainer) Close() error {
func (c *offloadContainer) Close() error {
errs := []error{}
// Signal all readers blocked in poll to wake up and exit

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"io"
"os"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
@@ -28,7 +27,7 @@ const tunSegBufCap = tunSegBufSize * 2
const tunDrainCap = 64
// gsoInitialPayIovs is the starting capacity (in payload fragments) of
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without
// Offload.gsoIovs. Sized to cover the default coalesce segment cap without
// any reallocations.
const gsoInitialPayIovs = 66
@@ -42,9 +41,9 @@ const gsoInitialPayIovs = 66
// safe.
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct { //todo rename GSO
type Offload struct {
fd int
shutdownFd int
readPoll [2]unix.PollFd
@@ -71,12 +70,12 @@ type tunFile struct { //todo rename GSO
gsoIovs []unix.Iovec
}
func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
func newOffload(fd int, shutdownFd int) (*Offload, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
out := &tunFile{
out := &Offload{
fd: fd,
shutdownFd: shutdownFd,
closed: atomic.Bool{},
@@ -104,7 +103,7 @@ func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
return out, nil
}
func (r *tunFile) blockOnRead() error {
func (r *Offload) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
@@ -130,7 +129,7 @@ func (r *tunFile) blockOnRead() error {
return nil
}
func (r *tunFile) blockOnWrite() error {
func (r *Offload) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
@@ -156,7 +155,7 @@ func (r *tunFile) blockOnWrite() error {
return nil
}
func (r *tunFile) readRaw(buf []byte) (int, error) {
func (r *Offload) readRaw(buf []byte) (int, error) {
for {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
@@ -180,9 +179,9 @@ func (r *tunFile) readRaw(buf []byte) (int, error) {
// readable we drain additional packets non-blocking until the kernel queue
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
// segBuf headroom. This amortizes the poll wake over bursts of small
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers
// packets (e.g. TCP ACKs). Slices point into the Offload's internal buffers
// and are only valid until the next Read or Close on this Queue.
func (r *tunFile) Read() ([][]byte, error) {
func (r *Offload) Read() ([][]byte, error) {
r.pending = r.pending[:0]
r.segOff = 0
@@ -226,7 +225,7 @@ func (r *tunFile) Read() ([][]byte, error) {
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
// the segments to r.pending, and advances r.segOff by the total scratch used.
// Caller must have already ensured r.vnetHdr is true.
func (r *tunFile) decodeRead(n int) error {
func (r *Offload) decodeRead(n int) error {
if n < virtioNetHdrLen {
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
}
@@ -242,7 +241,7 @@ func (r *tunFile) decodeRead(n int) error {
return nil
}
func (r *tunFile) Write(buf []byte) (int, error) {
func (r *Offload) Write(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.writeIovs)
}
@@ -250,36 +249,33 @@ func (r *tunFile) Write(buf []byte) (int, error) {
// distinct from the one used by the coalescer's Write path. This avoids a
// data race between the inside (listenIn) goroutine emitting reject or
// self-forward packets and the outside (listenOut) goroutine flushing TCP
// coalescer passthroughs on the same tunFile.
func (r *tunFile) WriteReject(buf []byte) (int, error) {
// coalescer passthroughs on the same Offload.
func (r *Offload) WriteReject(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.rejectIovs)
}
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
if len(buf) == 0 {
return 0, nil
}
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
// to validVnetHdr during tunFile construction so we don't rebuild it here.
// to validVnetHdr during Offload construction so we don't rebuild it here.
iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf))
iovPtr := uintptr(unsafe.Pointer(&iovs[0]))
// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev
// either completes promptly or returns EAGAIN — it cannot park the
// goroutine inside the kernel. That lets us use syscall.RawSyscall and
// skip the runtime.entersyscall / exitsyscall bookkeeping on every
// packet; we only pay that cost when we fall through to blockOnWrite.
iovPtr := unsafe.Pointer(&iovs[0])
return r.rawWrite(iovPtr, 2)
}
func (r *Offload) rawWrite(iovs unsafe.Pointer, iovcnt int) (int, error) {
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(iovs), uintptr(iovcnt))
if errno == 0 {
runtime.KeepAlive(buf)
if int(n) < virtioNetHdrLen {
return 0, io.ErrShortWrite
}
return int(n) - virtioNetHdrLen, nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(buf)
if err := r.blockOnWrite(); err != nil {
return 0, err
}
@@ -291,7 +287,6 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
if errno == unix.EBADF {
return 0, os.ErrClosed
}
runtime.KeepAlive(buf)
return 0, errno
}
}
@@ -299,7 +294,7 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
// can accept WriteGSO. When false, callers should fall back to per-segment
// Write calls.
func (r *tunFile) GSOSupported() bool { return true }
func (r *Offload) GSOSupported() bool { return true }
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
@@ -308,7 +303,7 @@ func (r *tunFile) GSOSupported() bool { return true }
// slice is read-only and must stay valid until return. gsoSize is the MSS;
// every segment except possibly the last is exactly gsoSize bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
if len(hdr) == 0 || len(pays) == 0 {
return nil
}
@@ -356,45 +351,18 @@ func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool,
r.gsoIovs[2+i].SetLen(len(p))
}
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
iovCnt := uintptr(len(r.gsoIovs))
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
if errno == 0 {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if int(n) < virtioNetHdrLen {
return io.ErrShortWrite
}
return nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if err := r.blockOnWrite(); err != nil {
return err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return os.ErrClosed
}
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
return errno
}
iovPtr := unsafe.Pointer(&r.gsoIovs[0])
iovCnt := len(r.gsoIovs)
_, err := r.rawWrite(iovPtr, iovCnt)
return err
}
func (r *tunFile) Close() error {
func (r *Offload) Close() error {
if r.closed.Swap(true) {
return nil
}
//shutdownFd is owned by the container, so we should not close it
var err error
if r.fd >= 0 {
err = unix.Close(r.fd)

View File

@@ -10,11 +10,12 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
// The caller takes ownership of the read fd (pass it to newOffload / newFriend).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
@@ -25,70 +26,35 @@ func newReadPipe(t *testing.T) int {
return fds[0]
}
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
func TestOffload_WakeForShutdown_WakesFriends(t *testing.T) {
pipe1 := newReadPipe(t)
pipe2 := newReadPipe(t)
parent, err := NewOffloadContainer()
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
t.Cleanup(func() { _ = tf.Close() })
done := make(chan error, 1)
go func() {
_, err := tf.Read(make([]byte, 64))
done <- err
}()
// Verify Read is actually blocked in poll.
select {
case err := <-done:
t.Fatalf("Read returned before shutdown signal: %v", err)
case <-time.After(50 * time.Millisecond):
}
if err := tf.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
select {
case err := <-done:
if !errors.Is(err, os.ErrClosed) {
t.Fatalf("expected os.ErrClosed, got %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Read did not wake on shutdown")
}
}
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
parent, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
friend, err := parent.newFriend(newReadPipe(t))
if err != nil {
_ = parent.Close()
t.Fatalf("newFriend: %v", err)
t.Fatalf("newOffload: %v", err)
}
require.NoError(t, parent.Add(pipe1))
require.NoError(t, parent.Add(pipe2))
t.Cleanup(func() {
_ = friend.Close()
_ = parent.Close()
_ = unix.Close(pipe1)
_ = unix.Close(pipe2)
})
readers := []*tunFile{parent, friend}
readers := parent.Queues()
errs := make([]error, len(readers))
var wg sync.WaitGroup
for i, r := range readers {
wg.Add(1)
go func(i int, r *tunFile) {
go func(i int, r Queue) {
defer wg.Done()
_, errs[i] = r.Read(make([]byte, 64))
_, errs[i] = r.Read()
}(i, r)
}
time.Sleep(50 * time.Millisecond)
if err := parent.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
if err := parent.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
done := make(chan struct{})
@@ -107,9 +73,9 @@ func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
}
func TestTunFile_Close_Idempotent(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
tf, err := newOffload(newReadPipe(t), 1)
if err != nil {
t.Fatalf("newTunFd: %v", err)
t.Fatalf("newOffload: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("first Close: %v", err)

View File

@@ -309,7 +309,7 @@ func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
}
t.Cleanup(func() { _ = unix.Close(fd) })
tf := &tunFile{fd: fd}
tf := &Offload{fd: fd}
tf.writeIovs[0].Base = &validVnetHdr[0]
tf.writeIovs[0].SetLen(virtioNetHdrLen)

View File

@@ -1,38 +0,0 @@
//go:build !e2e_testing
// +build !e2e_testing
package tio
import (
"testing"
"github.com/slackhq/nebula/overlay"
)
var runAdvMSSTests = []struct {
name string
tun *overlay.tun
r overlay.Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{}, 0},
{"default-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1440}, 0},
{"default-low", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{}, 1400},
{"route-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 1440}, 1400},
{"route-high", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}

View File

@@ -176,7 +176,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetwo
var container tio.Container
var err error
if vnetHdr {
container, err = tio.NewGSOContainer()
container, err = tio.NewOffloadContainer()
} else {
container, err = tio.NewPollContainer()
}

36
overlay/tun_linux_test.go Normal file
View File

@@ -0,0 +1,36 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"testing"
)
var runAdvMSSTests = []struct {
name string
tun *tun
r Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}

View File

@@ -24,6 +24,7 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
outboundWriter: ow,
inboundReader: ir,
inboundWriter: iw,
numReaders: 1,
}, nil
}