mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
fix tests
This commit is contained in:
@@ -8,21 +8,21 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type gsoContainer struct {
|
||||
pq []*tunFile
|
||||
type offloadContainer struct {
|
||||
pq []*Offload
|
||||
// pqi is exactly the same as pq, but stored as the interface type
|
||||
pqi []Queue
|
||||
shutdownFd int
|
||||
}
|
||||
|
||||
func NewGSOContainer() (Container, error) {
|
||||
func NewOffloadContainer() (Container, error) {
|
||||
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||
}
|
||||
|
||||
out := &gsoContainer{
|
||||
pq: []*tunFile{},
|
||||
out := &offloadContainer{
|
||||
pq: []*Offload{},
|
||||
pqi: []Queue{},
|
||||
shutdownFd: shutdownFd,
|
||||
}
|
||||
@@ -30,12 +30,12 @@ func NewGSOContainer() (Container, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *gsoContainer) Queues() []Queue {
|
||||
func (c *offloadContainer) Queues() []Queue {
|
||||
return c.pqi
|
||||
}
|
||||
|
||||
func (c *gsoContainer) Add(fd int) error {
|
||||
x, err := newTunFd(fd, c.shutdownFd)
|
||||
func (c *offloadContainer) Add(fd int) error {
|
||||
x, err := newOffload(fd, c.shutdownFd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -45,14 +45,14 @@ func (c *gsoContainer) Add(fd int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gsoContainer) wakeForShutdown() error {
|
||||
func (c *offloadContainer) wakeForShutdown() error {
|
||||
var buf [8]byte
|
||||
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||
_, err := unix.Write(int(c.shutdownFd), buf[:])
|
||||
_, err := unix.Write(c.shutdownFd, buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gsoContainer) Close() error {
|
||||
func (c *offloadContainer) Close() error {
|
||||
errs := []error{}
|
||||
|
||||
// Signal all readers blocked in poll to wake up and exit
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@@ -28,7 +27,7 @@ const tunSegBufCap = tunSegBufSize * 2
|
||||
const tunDrainCap = 64
|
||||
|
||||
// gsoInitialPayIovs is the starting capacity (in payload fragments) of
|
||||
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without
|
||||
// Offload.gsoIovs. Sized to cover the default coalesce segment cap without
|
||||
// any reallocations.
|
||||
const gsoInitialPayIovs = 66
|
||||
|
||||
@@ -42,9 +41,9 @@ const gsoInitialPayIovs = 66
|
||||
// safe.
|
||||
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
|
||||
|
||||
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
|
||||
// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
|
||||
// A shared eventfd allows Close to wake all readers blocked in poll.
|
||||
type tunFile struct { //todo rename GSO
|
||||
type Offload struct {
|
||||
fd int
|
||||
shutdownFd int
|
||||
readPoll [2]unix.PollFd
|
||||
@@ -71,12 +70,12 @@ type tunFile struct { //todo rename GSO
|
||||
gsoIovs []unix.Iovec
|
||||
}
|
||||
|
||||
func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
|
||||
func newOffload(fd int, shutdownFd int) (*Offload, error) {
|
||||
if err := unix.SetNonblock(fd, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||
}
|
||||
|
||||
out := &tunFile{
|
||||
out := &Offload{
|
||||
fd: fd,
|
||||
shutdownFd: shutdownFd,
|
||||
closed: atomic.Bool{},
|
||||
@@ -104,7 +103,7 @@ func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *tunFile) blockOnRead() error {
|
||||
func (r *Offload) blockOnRead() error {
|
||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||
var err error
|
||||
for {
|
||||
@@ -130,7 +129,7 @@ func (r *tunFile) blockOnRead() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tunFile) blockOnWrite() error {
|
||||
func (r *Offload) blockOnWrite() error {
|
||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||
var err error
|
||||
for {
|
||||
@@ -156,7 +155,7 @@ func (r *tunFile) blockOnWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tunFile) readRaw(buf []byte) (int, error) {
|
||||
func (r *Offload) readRaw(buf []byte) (int, error) {
|
||||
for {
|
||||
if n, err := unix.Read(r.fd, buf); err == nil {
|
||||
return n, nil
|
||||
@@ -180,9 +179,9 @@ func (r *tunFile) readRaw(buf []byte) (int, error) {
|
||||
// readable we drain additional packets non-blocking until the kernel queue
|
||||
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
|
||||
// segBuf headroom. This amortizes the poll wake over bursts of small
|
||||
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers
|
||||
// packets (e.g. TCP ACKs). Slices point into the Offload's internal buffers
|
||||
// and are only valid until the next Read or Close on this Queue.
|
||||
func (r *tunFile) Read() ([][]byte, error) {
|
||||
func (r *Offload) Read() ([][]byte, error) {
|
||||
r.pending = r.pending[:0]
|
||||
r.segOff = 0
|
||||
|
||||
@@ -226,7 +225,7 @@ func (r *tunFile) Read() ([][]byte, error) {
|
||||
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
|
||||
// the segments to r.pending, and advances r.segOff by the total scratch used.
|
||||
// Caller must have already ensured r.vnetHdr is true.
|
||||
func (r *tunFile) decodeRead(n int) error {
|
||||
func (r *Offload) decodeRead(n int) error {
|
||||
if n < virtioNetHdrLen {
|
||||
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
|
||||
}
|
||||
@@ -242,7 +241,7 @@ func (r *tunFile) decodeRead(n int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tunFile) Write(buf []byte) (int, error) {
|
||||
func (r *Offload) Write(buf []byte) (int, error) {
|
||||
return r.writeWithScratch(buf, &r.writeIovs)
|
||||
}
|
||||
|
||||
@@ -250,36 +249,33 @@ func (r *tunFile) Write(buf []byte) (int, error) {
|
||||
// distinct from the one used by the coalescer's Write path. This avoids a
|
||||
// data race between the inside (listenIn) goroutine emitting reject or
|
||||
// self-forward packets and the outside (listenOut) goroutine flushing TCP
|
||||
// coalescer passthroughs on the same tunFile.
|
||||
func (r *tunFile) WriteReject(buf []byte) (int, error) {
|
||||
// coalescer passthroughs on the same Offload.
|
||||
func (r *Offload) WriteReject(buf []byte) (int, error) {
|
||||
return r.writeWithScratch(buf, &r.rejectIovs)
|
||||
}
|
||||
|
||||
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
|
||||
func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
|
||||
// to validVnetHdr during tunFile construction so we don't rebuild it here.
|
||||
// to validVnetHdr during Offload construction so we don't rebuild it here.
|
||||
iovs[1].Base = &buf[0]
|
||||
iovs[1].SetLen(len(buf))
|
||||
iovPtr := uintptr(unsafe.Pointer(&iovs[0]))
|
||||
// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev
|
||||
// either completes promptly or returns EAGAIN — it cannot park the
|
||||
// goroutine inside the kernel. That lets us use syscall.RawSyscall and
|
||||
// skip the runtime.entersyscall / exitsyscall bookkeeping on every
|
||||
// packet; we only pay that cost when we fall through to blockOnWrite.
|
||||
iovPtr := unsafe.Pointer(&iovs[0])
|
||||
return r.rawWrite(iovPtr, 2)
|
||||
}
|
||||
|
||||
func (r *Offload) rawWrite(iovs unsafe.Pointer, iovcnt int) (int, error) {
|
||||
for {
|
||||
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
|
||||
n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(iovs), uintptr(iovcnt))
|
||||
if errno == 0 {
|
||||
runtime.KeepAlive(buf)
|
||||
if int(n) < virtioNetHdrLen {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return int(n) - virtioNetHdrLen, nil
|
||||
}
|
||||
if errno == unix.EAGAIN {
|
||||
runtime.KeepAlive(buf)
|
||||
if err := r.blockOnWrite(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -291,7 +287,6 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
|
||||
if errno == unix.EBADF {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
runtime.KeepAlive(buf)
|
||||
return 0, errno
|
||||
}
|
||||
}
|
||||
@@ -299,7 +294,7 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
|
||||
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
|
||||
// can accept WriteGSO. When false, callers should fall back to per-segment
|
||||
// Write calls.
|
||||
func (r *tunFile) GSOSupported() bool { return true }
|
||||
func (r *Offload) GSOSupported() bool { return true }
|
||||
|
||||
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
|
||||
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
|
||||
@@ -308,7 +303,7 @@ func (r *tunFile) GSOSupported() bool { return true }
|
||||
// slice is read-only and must stay valid until return. gsoSize is the MSS;
|
||||
// every segment except possibly the last is exactly gsoSize bytes.
|
||||
// csumStart is the byte offset where the TCP header begins within hdr.
|
||||
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
|
||||
func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
|
||||
if len(hdr) == 0 || len(pays) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -356,45 +351,18 @@ func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool,
|
||||
r.gsoIovs[2+i].SetLen(len(p))
|
||||
}
|
||||
|
||||
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
|
||||
iovCnt := uintptr(len(r.gsoIovs))
|
||||
for {
|
||||
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
|
||||
if errno == 0 {
|
||||
runtime.KeepAlive(hdr)
|
||||
runtime.KeepAlive(pays)
|
||||
if int(n) < virtioNetHdrLen {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if errno == unix.EAGAIN {
|
||||
runtime.KeepAlive(hdr)
|
||||
runtime.KeepAlive(pays)
|
||||
if err := r.blockOnWrite(); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
if errno == unix.EBADF {
|
||||
return os.ErrClosed
|
||||
}
|
||||
runtime.KeepAlive(hdr)
|
||||
runtime.KeepAlive(pays)
|
||||
return errno
|
||||
}
|
||||
iovPtr := unsafe.Pointer(&r.gsoIovs[0])
|
||||
iovCnt := len(r.gsoIovs)
|
||||
_, err := r.rawWrite(iovPtr, iovCnt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *tunFile) Close() error {
|
||||
func (r *Offload) Close() error {
|
||||
if r.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
//shutdownFd is owned by the container, so we should not close it
|
||||
|
||||
var err error
|
||||
if r.fd >= 0 {
|
||||
err = unix.Close(r.fd)
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
||||
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
|
||||
// The caller takes ownership of the read fd (pass it to newOffload / newFriend).
|
||||
func newReadPipe(t *testing.T) int {
|
||||
t.Helper()
|
||||
var fds [2]int
|
||||
@@ -25,70 +26,35 @@ func newReadPipe(t *testing.T) int {
|
||||
return fds[0]
|
||||
}
|
||||
|
||||
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
|
||||
tf, err := newTunFd(newReadPipe(t))
|
||||
func TestOffload_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||
pipe1 := newReadPipe(t)
|
||||
pipe2 := newReadPipe(t)
|
||||
parent, err := NewOffloadContainer()
|
||||
if err != nil {
|
||||
t.Fatalf("newTunFd: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = tf.Close() })
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := tf.Read(make([]byte, 64))
|
||||
done <- err
|
||||
}()
|
||||
|
||||
// Verify Read is actually blocked in poll.
|
||||
select {
|
||||
case err := <-done:
|
||||
t.Fatalf("Read returned before shutdown signal: %v", err)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
if err := tf.wakeForShutdown(); err != nil {
|
||||
t.Fatalf("wakeForShutdown: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if !errors.Is(err, os.ErrClosed) {
|
||||
t.Fatalf("expected os.ErrClosed, got %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Read did not wake on shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||
parent, err := newTunFd(newReadPipe(t))
|
||||
if err != nil {
|
||||
t.Fatalf("newTunFd: %v", err)
|
||||
}
|
||||
friend, err := parent.newFriend(newReadPipe(t))
|
||||
if err != nil {
|
||||
_ = parent.Close()
|
||||
t.Fatalf("newFriend: %v", err)
|
||||
t.Fatalf("newOffload: %v", err)
|
||||
}
|
||||
require.NoError(t, parent.Add(pipe1))
|
||||
require.NoError(t, parent.Add(pipe2))
|
||||
t.Cleanup(func() {
|
||||
_ = friend.Close()
|
||||
_ = parent.Close()
|
||||
_ = unix.Close(pipe1)
|
||||
_ = unix.Close(pipe2)
|
||||
})
|
||||
|
||||
readers := []*tunFile{parent, friend}
|
||||
readers := parent.Queues()
|
||||
errs := make([]error, len(readers))
|
||||
var wg sync.WaitGroup
|
||||
for i, r := range readers {
|
||||
wg.Add(1)
|
||||
go func(i int, r *tunFile) {
|
||||
go func(i int, r Queue) {
|
||||
defer wg.Done()
|
||||
_, errs[i] = r.Read(make([]byte, 64))
|
||||
_, errs[i] = r.Read()
|
||||
}(i, r)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if err := parent.wakeForShutdown(); err != nil {
|
||||
t.Fatalf("wakeForShutdown: %v", err)
|
||||
if err := parent.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -107,9 +73,9 @@ func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTunFile_Close_Idempotent(t *testing.T) {
|
||||
tf, err := newTunFd(newReadPipe(t))
|
||||
tf, err := newOffload(newReadPipe(t), 1)
|
||||
if err != nil {
|
||||
t.Fatalf("newTunFd: %v", err)
|
||||
t.Fatalf("newOffload: %v", err)
|
||||
}
|
||||
if err := tf.Close(); err != nil {
|
||||
t.Fatalf("first Close: %v", err)
|
||||
|
||||
@@ -309,7 +309,7 @@ func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(func() { _ = unix.Close(fd) })
|
||||
|
||||
tf := &tunFile{fd: fd}
|
||||
tf := &Offload{fd: fd}
|
||||
tf.writeIovs[0].Base = &validVnetHdr[0]
|
||||
tf.writeIovs[0].SetLen(virtioNetHdrLen)
|
||||
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package tio
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
)
|
||||
|
||||
var runAdvMSSTests = []struct {
|
||||
name string
|
||||
tun *overlay.tun
|
||||
r overlay.Route
|
||||
expected int
|
||||
}{
|
||||
// Standard case, default MTU is the device max MTU
|
||||
{"default", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{}, 0},
|
||||
{"default-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1440}, 0},
|
||||
{"default-low", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1200}, 1160},
|
||||
|
||||
// Case where we have a route MTU set higher than the default
|
||||
{"route", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{}, 1400},
|
||||
{"route-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 1440}, 1400},
|
||||
{"route-high", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 8941}, 0},
|
||||
}
|
||||
|
||||
func TestTunAdvMSS(t *testing.T) {
|
||||
for _, tt := range runAdvMSSTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := tt.tun.advMSS(tt.r)
|
||||
if o != tt.expected {
|
||||
t.Errorf("got %d, want %d", o, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -176,7 +176,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetwo
|
||||
var container tio.Container
|
||||
var err error
|
||||
if vnetHdr {
|
||||
container, err = tio.NewGSOContainer()
|
||||
container, err = tio.NewOffloadContainer()
|
||||
} else {
|
||||
container, err = tio.NewPollContainer()
|
||||
}
|
||||
|
||||
36
overlay/tun_linux_test.go
Normal file
36
overlay/tun_linux_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var runAdvMSSTests = []struct {
|
||||
name string
|
||||
tun *tun
|
||||
r Route
|
||||
expected int
|
||||
}{
|
||||
// Standard case, default MTU is the device max MTU
|
||||
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
|
||||
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
|
||||
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
|
||||
|
||||
// Case where we have a route MTU set higher than the default
|
||||
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
|
||||
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
|
||||
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
|
||||
}
|
||||
|
||||
func TestTunAdvMSS(t *testing.T) {
|
||||
for _, tt := range runAdvMSSTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := tt.tun.advMSS(tt.r)
|
||||
if o != tt.expected {
|
||||
t.Errorf("got %d, want %d", o, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
|
||||
outboundWriter: ow,
|
||||
inboundReader: ir,
|
||||
inboundWriter: iw,
|
||||
numReaders: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user