fix tests

This commit is contained in:
JackDoan
2026-04-23 11:35:51 -05:00
parent 382b15ac52
commit f76ac2e216
8 changed files with 97 additions and 164 deletions

View File

@@ -8,21 +8,21 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type gsoContainer struct { type offloadContainer struct {
pq []*tunFile pq []*Offload
// pqi is exactly the same as pq, but stored as the interface type // pqi is exactly the same as pq, but stored as the interface type
pqi []Queue pqi []Queue
shutdownFd int shutdownFd int
} }
func NewGSOContainer() (Container, error) { func NewOffloadContainer() (Container, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err) return nil, fmt.Errorf("failed to create eventfd: %w", err)
} }
out := &gsoContainer{ out := &offloadContainer{
pq: []*tunFile{}, pq: []*Offload{},
pqi: []Queue{}, pqi: []Queue{},
shutdownFd: shutdownFd, shutdownFd: shutdownFd,
} }
@@ -30,12 +30,12 @@ func NewGSOContainer() (Container, error) {
return out, nil return out, nil
} }
func (c *gsoContainer) Queues() []Queue { func (c *offloadContainer) Queues() []Queue {
return c.pqi return c.pqi
} }
func (c *gsoContainer) Add(fd int) error { func (c *offloadContainer) Add(fd int) error {
x, err := newTunFd(fd, c.shutdownFd) x, err := newOffload(fd, c.shutdownFd)
if err != nil { if err != nil {
return err return err
} }
@@ -45,14 +45,14 @@ func (c *gsoContainer) Add(fd int) error {
return nil return nil
} }
func (c *gsoContainer) wakeForShutdown() error { func (c *offloadContainer) wakeForShutdown() error {
var buf [8]byte var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1) binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:]) _, err := unix.Write(c.shutdownFd, buf[:])
return err return err
} }
func (c *gsoContainer) Close() error { func (c *offloadContainer) Close() error {
errs := []error{} errs := []error{}
// Signal all readers blocked in poll to wake up and exit // Signal all readers blocked in poll to wake up and exit

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"unsafe" "unsafe"
@@ -28,7 +27,7 @@ const tunSegBufCap = tunSegBufSize * 2
const tunDrainCap = 64 const tunDrainCap = 64
// gsoInitialPayIovs is the starting capacity (in payload fragments) of // gsoInitialPayIovs is the starting capacity (in payload fragments) of
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without // Offload.gsoIovs. Sized to cover the default coalesce segment cap without
// any reallocations. // any reallocations.
const gsoInitialPayIovs = 66 const gsoInitialPayIovs = 66
@@ -42,9 +41,9 @@ const gsoInitialPayIovs = 66
// safe. // safe.
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID} var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. // Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll. // A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct { //todo rename GSO type Offload struct {
fd int fd int
shutdownFd int shutdownFd int
readPoll [2]unix.PollFd readPoll [2]unix.PollFd
@@ -71,12 +70,12 @@ type tunFile struct { //todo rename GSO
gsoIovs []unix.Iovec gsoIovs []unix.Iovec
} }
func newTunFd(fd int, shutdownFd int) (*tunFile, error) { func newOffload(fd int, shutdownFd int) (*Offload, error) {
if err := unix.SetNonblock(fd, true); err != nil { if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
} }
out := &tunFile{ out := &Offload{
fd: fd, fd: fd,
shutdownFd: shutdownFd, shutdownFd: shutdownFd,
closed: atomic.Bool{}, closed: atomic.Bool{},
@@ -104,7 +103,7 @@ func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
return out, nil return out, nil
} }
func (r *tunFile) blockOnRead() error { func (r *Offload) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error var err error
for { for {
@@ -130,7 +129,7 @@ func (r *tunFile) blockOnRead() error {
return nil return nil
} }
func (r *tunFile) blockOnWrite() error { func (r *Offload) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error var err error
for { for {
@@ -156,7 +155,7 @@ func (r *tunFile) blockOnWrite() error {
return nil return nil
} }
func (r *tunFile) readRaw(buf []byte) (int, error) { func (r *Offload) readRaw(buf []byte) (int, error) {
for { for {
if n, err := unix.Read(r.fd, buf); err == nil { if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil return n, nil
@@ -180,9 +179,9 @@ func (r *tunFile) readRaw(buf []byte) (int, error) {
// readable we drain additional packets non-blocking until the kernel queue // readable we drain additional packets non-blocking until the kernel queue
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of // is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
// segBuf headroom. This amortizes the poll wake over bursts of small // segBuf headroom. This amortizes the poll wake over bursts of small
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers // packets (e.g. TCP ACKs). Slices point into the Offload's internal buffers
// and are only valid until the next Read or Close on this Queue. // and are only valid until the next Read or Close on this Queue.
func (r *tunFile) Read() ([][]byte, error) { func (r *Offload) Read() ([][]byte, error) {
r.pending = r.pending[:0] r.pending = r.pending[:0]
r.segOff = 0 r.segOff = 0
@@ -226,7 +225,7 @@ func (r *tunFile) Read() ([][]byte, error) {
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends // decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
// the segments to r.pending, and advances r.segOff by the total scratch used. // the segments to r.pending, and advances r.segOff by the total scratch used.
// Caller must have already ensured r.vnetHdr is true. // Caller must have already ensured r.vnetHdr is true.
func (r *tunFile) decodeRead(n int) error { func (r *Offload) decodeRead(n int) error {
if n < virtioNetHdrLen { if n < virtioNetHdrLen {
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen) return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
} }
@@ -242,7 +241,7 @@ func (r *tunFile) decodeRead(n int) error {
return nil return nil
} }
func (r *tunFile) Write(buf []byte) (int, error) { func (r *Offload) Write(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.writeIovs) return r.writeWithScratch(buf, &r.writeIovs)
} }
@@ -250,36 +249,33 @@ func (r *tunFile) Write(buf []byte) (int, error) {
// distinct from the one used by the coalescer's Write path. This avoids a // distinct from the one used by the coalescer's Write path. This avoids a
// data race between the inside (listenIn) goroutine emitting reject or // data race between the inside (listenIn) goroutine emitting reject or
// self-forward packets and the outside (listenOut) goroutine flushing TCP // self-forward packets and the outside (listenOut) goroutine flushing TCP
// coalescer passthroughs on the same tunFile. // coalescer passthroughs on the same Offload.
func (r *tunFile) WriteReject(buf []byte) (int, error) { func (r *Offload) WriteReject(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.rejectIovs) return r.writeWithScratch(buf, &r.rejectIovs)
} }
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) { func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
if len(buf) == 0 { if len(buf) == 0 {
return 0, nil return 0, nil
} }
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired // Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
// to validVnetHdr during tunFile construction so we don't rebuild it here. // to validVnetHdr during Offload construction so we don't rebuild it here.
iovs[1].Base = &buf[0] iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf)) iovs[1].SetLen(len(buf))
iovPtr := uintptr(unsafe.Pointer(&iovs[0])) iovPtr := unsafe.Pointer(&iovs[0])
// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev return r.rawWrite(iovPtr, 2)
// either completes promptly or returns EAGAIN — it cannot park the }
// goroutine inside the kernel. That lets us use syscall.RawSyscall and
// skip the runtime.entersyscall / exitsyscall bookkeeping on every func (r *Offload) rawWrite(iovs unsafe.Pointer, iovcnt int) (int, error) {
// packet; we only pay that cost when we fall through to blockOnWrite.
for { for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2) n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(iovs), uintptr(iovcnt))
if errno == 0 { if errno == 0 {
runtime.KeepAlive(buf)
if int(n) < virtioNetHdrLen { if int(n) < virtioNetHdrLen {
return 0, io.ErrShortWrite return 0, io.ErrShortWrite
} }
return int(n) - virtioNetHdrLen, nil return int(n) - virtioNetHdrLen, nil
} }
if errno == unix.EAGAIN { if errno == unix.EAGAIN {
runtime.KeepAlive(buf)
if err := r.blockOnWrite(); err != nil { if err := r.blockOnWrite(); err != nil {
return 0, err return 0, err
} }
@@ -291,7 +287,6 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
if errno == unix.EBADF { if errno == unix.EBADF {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
runtime.KeepAlive(buf)
return 0, errno return 0, errno
} }
} }
@@ -299,7 +294,7 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and // GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
// can accept WriteGSO. When false, callers should fall back to per-segment // can accept WriteGSO. When false, callers should fall back to per-segment
// Write calls. // Write calls.
func (r *tunFile) GSOSupported() bool { return true } func (r *Offload) GSOSupported() bool { return true }
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the // WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum, // IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
@@ -308,7 +303,7 @@ func (r *tunFile) GSOSupported() bool { return true }
// slice is read-only and must stay valid until return. gsoSize is the MSS; // slice is read-only and must stay valid until return. gsoSize is the MSS;
// every segment except possibly the last is exactly gsoSize bytes. // every segment except possibly the last is exactly gsoSize bytes.
// csumStart is the byte offset where the TCP header begins within hdr. // csumStart is the byte offset where the TCP header begins within hdr.
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error { func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
if len(hdr) == 0 || len(pays) == 0 { if len(hdr) == 0 || len(pays) == 0 {
return nil return nil
} }
@@ -356,45 +351,18 @@ func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool,
r.gsoIovs[2+i].SetLen(len(p)) r.gsoIovs[2+i].SetLen(len(p))
} }
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0])) iovPtr := unsafe.Pointer(&r.gsoIovs[0])
iovCnt := uintptr(len(r.gsoIovs)) iovCnt := len(r.gsoIovs)
for { _, err := r.rawWrite(iovPtr, iovCnt)
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
if errno == 0 {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if int(n) < virtioNetHdrLen {
return io.ErrShortWrite
}
return nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if err := r.blockOnWrite(); err != nil {
return err return err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return os.ErrClosed
}
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
return errno
}
} }
func (r *tunFile) Close() error { func (r *Offload) Close() error {
if r.closed.Swap(true) { if r.closed.Swap(true) {
return nil return nil
} }
//shutdownFd is owned by the container, so we should not close it //shutdownFd is owned by the container, so we should not close it
var err error var err error
if r.fd >= 0 { if r.fd >= 0 {
err = unix.Close(r.fd) err = unix.Close(r.fd)

View File

@@ -10,11 +10,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// newReadPipe returns a read fd. The matching write fd is registered for cleanup. // newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). // The caller takes ownership of the read fd (pass it to newOffload / newFriend).
func newReadPipe(t *testing.T) int { func newReadPipe(t *testing.T) int {
t.Helper() t.Helper()
var fds [2]int var fds [2]int
@@ -25,70 +26,35 @@ func newReadPipe(t *testing.T) int {
return fds[0] return fds[0]
} }
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { func TestOffload_WakeForShutdown_WakesFriends(t *testing.T) {
tf, err := newTunFd(newReadPipe(t)) pipe1 := newReadPipe(t)
pipe2 := newReadPipe(t)
parent, err := NewOffloadContainer()
if err != nil { if err != nil {
t.Fatalf("newTunFd: %v", err) t.Fatalf("newOffload: %v", err)
}
t.Cleanup(func() { _ = tf.Close() })
done := make(chan error, 1)
go func() {
_, err := tf.Read(make([]byte, 64))
done <- err
}()
// Verify Read is actually blocked in poll.
select {
case err := <-done:
t.Fatalf("Read returned before shutdown signal: %v", err)
case <-time.After(50 * time.Millisecond):
}
if err := tf.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
select {
case err := <-done:
if !errors.Is(err, os.ErrClosed) {
t.Fatalf("expected os.ErrClosed, got %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Read did not wake on shutdown")
}
}
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
parent, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
friend, err := parent.newFriend(newReadPipe(t))
if err != nil {
_ = parent.Close()
t.Fatalf("newFriend: %v", err)
} }
require.NoError(t, parent.Add(pipe1))
require.NoError(t, parent.Add(pipe2))
t.Cleanup(func() { t.Cleanup(func() {
_ = friend.Close() _ = unix.Close(pipe1)
_ = parent.Close() _ = unix.Close(pipe2)
}) })
readers := []*tunFile{parent, friend} readers := parent.Queues()
errs := make([]error, len(readers)) errs := make([]error, len(readers))
var wg sync.WaitGroup var wg sync.WaitGroup
for i, r := range readers { for i, r := range readers {
wg.Add(1) wg.Add(1)
go func(i int, r *tunFile) { go func(i int, r Queue) {
defer wg.Done() defer wg.Done()
_, errs[i] = r.Read(make([]byte, 64)) _, errs[i] = r.Read()
}(i, r) }(i, r)
} }
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if err := parent.wakeForShutdown(); err != nil { if err := parent.Close(); err != nil {
t.Fatalf("wakeForShutdown: %v", err) t.Fatalf("Close: %v", err)
} }
done := make(chan struct{}) done := make(chan struct{})
@@ -107,9 +73,9 @@ func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
} }
func TestTunFile_Close_Idempotent(t *testing.T) { func TestTunFile_Close_Idempotent(t *testing.T) {
tf, err := newTunFd(newReadPipe(t)) tf, err := newOffload(newReadPipe(t), 1)
if err != nil { if err != nil {
t.Fatalf("newTunFd: %v", err) t.Fatalf("newOffload: %v", err)
} }
if err := tf.Close(); err != nil { if err := tf.Close(); err != nil {
t.Fatalf("first Close: %v", err) t.Fatalf("first Close: %v", err)

View File

@@ -309,7 +309,7 @@ func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
} }
t.Cleanup(func() { _ = unix.Close(fd) }) t.Cleanup(func() { _ = unix.Close(fd) })
tf := &tunFile{fd: fd} tf := &Offload{fd: fd}
tf.writeIovs[0].Base = &validVnetHdr[0] tf.writeIovs[0].Base = &validVnetHdr[0]
tf.writeIovs[0].SetLen(virtioNetHdrLen) tf.writeIovs[0].SetLen(virtioNetHdrLen)

View File

@@ -1,38 +0,0 @@
//go:build !e2e_testing
// +build !e2e_testing
package tio
import (
"testing"
"github.com/slackhq/nebula/overlay"
)
var runAdvMSSTests = []struct {
name string
tun *overlay.tun
r overlay.Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{}, 0},
{"default-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1440}, 0},
{"default-low", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{}, 1400},
{"route-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 1440}, 1400},
{"route-high", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}

View File

@@ -176,7 +176,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetwo
var container tio.Container var container tio.Container
var err error var err error
if vnetHdr { if vnetHdr {
container, err = tio.NewGSOContainer() container, err = tio.NewOffloadContainer()
} else { } else {
container, err = tio.NewPollContainer() container, err = tio.NewPollContainer()
} }

36
overlay/tun_linux_test.go Normal file
View File

@@ -0,0 +1,36 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"testing"
)
// runAdvMSSTests is the case table consumed by TestTunAdvMSS. Each entry
// pairs a tun (only DefaultMTU and MaxMTU are populated) with a Route and the
// value advMSS is expected to return for that combination.
var runAdvMSSTests = []struct {
	name     string
	tun      *tun
	r        Route
	expected int
}{
	// The device max MTU is the same as the default MTU.
	{name: "default", tun: &tun{DefaultMTU: 1440, MaxMTU: 1440}, r: Route{}, expected: 0},
	{name: "default-min", tun: &tun{DefaultMTU: 1440, MaxMTU: 1440}, r: Route{MTU: 1440}, expected: 0},
	{name: "default-low", tun: &tun{DefaultMTU: 1440, MaxMTU: 1440}, r: Route{MTU: 1200}, expected: 1160},

	// The device max MTU is larger than the default MTU, i.e. some route
	// carries an MTU above the default.
	{name: "route", tun: &tun{DefaultMTU: 1440, MaxMTU: 8941}, r: Route{}, expected: 1400},
	{name: "route-min", tun: &tun{DefaultMTU: 1440, MaxMTU: 8941}, r: Route{MTU: 1440}, expected: 1400},
	{name: "route-high", tun: &tun{DefaultMTU: 1440, MaxMTU: 8941}, r: Route{MTU: 8941}, expected: 0},
}
// TestTunAdvMSS runs each entry of runAdvMSSTests as a named subtest and
// reports an error when advMSS on the entry's tun, called with the entry's
// route, does not return the expected value.
func TestTunAdvMSS(t *testing.T) {
	for i := range runAdvMSSTests {
		tc := runAdvMSSTests[i]
		t.Run(tc.name, func(t *testing.T) {
			if got, want := tc.tun.advMSS(tc.r), tc.expected; got != want {
				t.Errorf("got %d, want %d", got, want)
			}
		})
	}
}

View File

@@ -24,6 +24,7 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
outboundWriter: ow, outboundWriter: ow,
inboundReader: ir, inboundReader: ir,
inboundWriter: iw, inboundWriter: iw,
numReaders: 1,
}, nil }, nil
} }