diff --git a/overlay/tio/container_gso_linux.go b/overlay/tio/container_gso_linux.go index f5260d68..aff23a37 100644 --- a/overlay/tio/container_gso_linux.go +++ b/overlay/tio/container_gso_linux.go @@ -8,21 +8,21 @@ import ( "golang.org/x/sys/unix" ) -type gsoContainer struct { - pq []*tunFile +type offloadContainer struct { + pq []*Offload // pqi is exactly the same as pq, but stored as the interface type pqi []Queue shutdownFd int } -func NewGSOContainer() (Container, error) { +func NewOffloadContainer() (Container, error) { shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) if err != nil { return nil, fmt.Errorf("failed to create eventfd: %w", err) } - out := &gsoContainer{ - pq: []*tunFile{}, + out := &offloadContainer{ + pq: []*Offload{}, pqi: []Queue{}, shutdownFd: shutdownFd, } @@ -30,12 +30,12 @@ func NewGSOContainer() (Container, error) { return out, nil } -func (c *gsoContainer) Queues() []Queue { +func (c *offloadContainer) Queues() []Queue { return c.pqi } -func (c *gsoContainer) Add(fd int) error { - x, err := newTunFd(fd, c.shutdownFd) +func (c *offloadContainer) Add(fd int) error { + x, err := newOffload(fd, c.shutdownFd) if err != nil { return err } @@ -45,14 +45,14 @@ func (c *gsoContainer) Add(fd int) error { return nil } -func (c *gsoContainer) wakeForShutdown() error { +func (c *offloadContainer) wakeForShutdown() error { var buf [8]byte binary.NativeEndian.PutUint64(buf[:], 1) - _, err := unix.Write(int(c.shutdownFd), buf[:]) + _, err := unix.Write(c.shutdownFd, buf[:]) return err } -func (c *gsoContainer) Close() error { +func (c *offloadContainer) Close() error { errs := []error{} // Signal all readers blocked in poll to wake up and exit diff --git a/overlay/tio/tio_gso_linux.go b/overlay/tio/tio_gso_linux.go index c5eeabb0..d665e585 100644 --- a/overlay/tio/tio_gso_linux.go +++ b/overlay/tio/tio_gso_linux.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "os" - "runtime" "sync/atomic" "syscall" "unsafe" @@ -28,7 +27,7 @@ const 
tunSegBufCap = tunSegBufSize * 2 const tunDrainCap = 64 // gsoInitialPayIovs is the starting capacity (in payload fragments) of -// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without +// Offload.gsoIovs. Sized to cover the default coalesce segment cap without // any reallocations. const gsoInitialPayIovs = 66 @@ -42,9 +41,9 @@ const gsoInitialPayIovs = 66 // safe. var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID} -// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. +// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. // A shared eventfd allows Close to wake all readers blocked in poll. -type tunFile struct { //todo rename GSO +type Offload struct { fd int shutdownFd int readPoll [2]unix.PollFd @@ -71,12 +70,12 @@ type tunFile struct { //todo rename GSO gsoIovs []unix.Iovec } -func newTunFd(fd int, shutdownFd int) (*tunFile, error) { +func newOffload(fd int, shutdownFd int) (*Offload, error) { if err := unix.SetNonblock(fd, true); err != nil { return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) } - out := &tunFile{ + out := &Offload{ fd: fd, shutdownFd: shutdownFd, closed: atomic.Bool{}, @@ -104,7 +103,7 @@ func newTunFd(fd int, shutdownFd int) (*tunFile, error) { return out, nil } -func (r *tunFile) blockOnRead() error { +func (r *Offload) blockOnRead() error { const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR var err error for { @@ -130,7 +129,7 @@ func (r *tunFile) blockOnRead() error { return nil } -func (r *tunFile) blockOnWrite() error { +func (r *Offload) blockOnWrite() error { const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR var err error for { @@ -156,7 +155,7 @@ func (r *tunFile) blockOnWrite() error { return nil } -func (r *tunFile) readRaw(buf []byte) (int, error) { +func (r *Offload) readRaw(buf []byte) (int, error) { for { if n, err := 
unix.Read(r.fd, buf); err == nil { return n, nil @@ -180,9 +179,9 @@ func (r *tunFile) readRaw(buf []byte) (int, error) { // readable we drain additional packets non-blocking until the kernel queue // is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of // segBuf headroom. This amortizes the poll wake over bursts of small -// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers +// packets (e.g. TCP ACKs). Slices point into the Offload's internal buffers // and are only valid until the next Read or Close on this Queue. -func (r *tunFile) Read() ([][]byte, error) { +func (r *Offload) Read() ([][]byte, error) { r.pending = r.pending[:0] r.segOff = 0 @@ -226,7 +225,7 @@ func (r *tunFile) Read() ([][]byte, error) { // decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends // the segments to r.pending, and advances r.segOff by the total scratch used. // Caller must have already ensured r.vnetHdr is true. -func (r *tunFile) decodeRead(n int) error { +func (r *Offload) decodeRead(n int) error { if n < virtioNetHdrLen { return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen) } @@ -242,7 +241,7 @@ func (r *tunFile) decodeRead(n int) error { return nil } -func (r *tunFile) Write(buf []byte) (int, error) { +func (r *Offload) Write(buf []byte) (int, error) { return r.writeWithScratch(buf, &r.writeIovs) } @@ -250,36 +249,33 @@ func (r *tunFile) Write(buf []byte) (int, error) { // distinct from the one used by the coalescer's Write path. This avoids a // data race between the inside (listenIn) goroutine emitting reject or // self-forward packets and the outside (listenOut) goroutine flushing TCP -// coalescer passthroughs on the same tunFile. -func (r *tunFile) WriteReject(buf []byte) (int, error) { +// coalescer passthroughs on the same Offload. 
+func (r *Offload) WriteReject(buf []byte) (int, error) { return r.writeWithScratch(buf, &r.rejectIovs) } -func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) { +func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) { if len(buf) == 0 { return 0, nil } // Point the payload iovec at the caller's buffer. iovs[0] is pre-wired - // to validVnetHdr during tunFile construction so we don't rebuild it here. + // to validVnetHdr during Offload construction so we don't rebuild it here. iovs[1].Base = &buf[0] iovs[1].SetLen(len(buf)) - iovPtr := uintptr(unsafe.Pointer(&iovs[0])) - // The TUN fd is non-blocking (set in newTunFd / newFriend), so writev - // either completes promptly or returns EAGAIN — it cannot park the - // goroutine inside the kernel. That lets us use syscall.RawSyscall and - // skip the runtime.entersyscall / exitsyscall bookkeeping on every - // packet; we only pay that cost when we fall through to blockOnWrite. + iovPtr := unsafe.Pointer(&iovs[0]) + return r.rawWrite(iovPtr, 2) +} + +func (r *Offload) rawWrite(iovs unsafe.Pointer, iovcnt int) (int, error) { for { - n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2) + n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(iovs), uintptr(iovcnt)) if errno == 0 { - runtime.KeepAlive(buf) if int(n) < virtioNetHdrLen { return 0, io.ErrShortWrite } return int(n) - virtioNetHdrLen, nil } if errno == unix.EAGAIN { - runtime.KeepAlive(buf) if err := r.blockOnWrite(); err != nil { return 0, err } @@ -291,7 +287,6 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) if errno == unix.EBADF { return 0, os.ErrClosed } - runtime.KeepAlive(buf) return 0, errno } } @@ -299,7 +294,7 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) // GSOSupported reports whether this queue was opened with IFF_VNET_HDR and // can accept WriteGSO. 
When false, callers should fall back to per-segment
 // Write calls.
-func (r *tunFile) GSOSupported() bool { return true }
+func (r *Offload) GSOSupported() bool { return true }
 
 // WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
 // IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
@@ -308,7 +303,7 @@ func (r *tunFile) GSOSupported() bool { return true }
 // slice is read-only and must stay valid until return. gsoSize is the MSS;
 // every segment except possibly the last is exactly gsoSize bytes.
 // csumStart is the byte offset where the TCP header begins within hdr.
-func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
+func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
 	if len(hdr) == 0 || len(pays) == 0 {
 		return nil
 	}
@@ -356,45 +351,19 @@ func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool,
 		r.gsoIovs[2+i].SetLen(len(p))
 	}
 
-	iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
-	iovCnt := uintptr(len(r.gsoIovs))
-	for {
-		n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
-		if errno == 0 {
-			runtime.KeepAlive(hdr)
-			runtime.KeepAlive(pays)
-			if int(n) < virtioNetHdrLen {
-				return io.ErrShortWrite
-			}
-			return nil
-		}
-		if errno == unix.EAGAIN {
-			runtime.KeepAlive(hdr)
-			runtime.KeepAlive(pays)
-			if err := r.blockOnWrite(); err != nil {
-				return err
-			}
-			continue
-		}
-		if errno == unix.EINTR {
-			continue
-		}
-		if errno == unix.EBADF {
-			return os.ErrClosed
-		}
-		runtime.KeepAlive(hdr)
-		runtime.KeepAlive(pays)
-		return errno
-	}
+	iovPtr := unsafe.Pointer(&r.gsoIovs[0])
+	iovCnt := len(r.gsoIovs)
+	_, err := r.rawWrite(iovPtr, iovCnt)
+	return err
 }
 
-func (r *tunFile) Close() error {
+func (r *Offload) Close() error {
 	if r.closed.Swap(true) {
 		return nil
 	}
 	//shutdownFd is owned by the container, so we should not close it
 	var err error
 	if r.fd >= 0 {
 		err = 
unix.Close(r.fd) diff --git a/overlay/tio/tun_file_linux_test.go b/overlay/tio/tun_file_linux_test.go index d162c7b3..6a87f487 100644 --- a/overlay/tio/tun_file_linux_test.go +++ b/overlay/tio/tun_file_linux_test.go @@ -10,11 +10,12 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" "golang.org/x/sys/unix" ) // newReadPipe returns a read fd. The matching write fd is registered for cleanup. -// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). +// The caller takes ownership of the read fd (pass it to newOffload / newFriend). func newReadPipe(t *testing.T) int { t.Helper() var fds [2]int @@ -25,70 +26,35 @@ func newReadPipe(t *testing.T) int { return fds[0] } -func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { - tf, err := newTunFd(newReadPipe(t)) +func TestOffload_WakeForShutdown_WakesFriends(t *testing.T) { + pipe1 := newReadPipe(t) + pipe2 := newReadPipe(t) + parent, err := NewOffloadContainer() if err != nil { - t.Fatalf("newTunFd: %v", err) - } - t.Cleanup(func() { _ = tf.Close() }) - - done := make(chan error, 1) - go func() { - _, err := tf.Read(make([]byte, 64)) - done <- err - }() - - // Verify Read is actually blocked in poll. 
-	select {
-	case err := <-done:
-		t.Fatalf("Read returned before shutdown signal: %v", err)
-	case <-time.After(50 * time.Millisecond):
-	}
-
-	if err := tf.wakeForShutdown(); err != nil {
-		t.Fatalf("wakeForShutdown: %v", err)
-	}
-
-	select {
-	case err := <-done:
-		if !errors.Is(err, os.ErrClosed) {
-			t.Fatalf("expected os.ErrClosed, got %v", err)
-		}
-	case <-time.After(2 * time.Second):
-		t.Fatal("Read did not wake on shutdown")
-	}
-}
-
-func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
-	parent, err := newTunFd(newReadPipe(t))
-	if err != nil {
-		t.Fatalf("newTunFd: %v", err)
-	}
-	friend, err := parent.newFriend(newReadPipe(t))
-	if err != nil {
-		_ = parent.Close()
-		t.Fatalf("newFriend: %v", err)
+		t.Fatalf("NewOffloadContainer: %v", err)
 	}
+	require.NoError(t, parent.Add(pipe1))
+	require.NoError(t, parent.Add(pipe2))
 	t.Cleanup(func() {
-		_ = friend.Close()
-		_ = parent.Close()
+		_ = unix.Close(pipe1)
+		_ = unix.Close(pipe2)
 	})
 
-	readers := []*tunFile{parent, friend}
+	readers := parent.Queues()
 	errs := make([]error, len(readers))
 	var wg sync.WaitGroup
 	for i, r := range readers {
 		wg.Add(1)
-		go func(i int, r *tunFile) {
+		go func(i int, r Queue) {
 			defer wg.Done()
-			_, errs[i] = r.Read(make([]byte, 64))
+			_, errs[i] = r.Read()
 		}(i, r)
 	}
 
 	time.Sleep(50 * time.Millisecond)
-	if err := parent.wakeForShutdown(); err != nil {
-		t.Fatalf("wakeForShutdown: %v", err)
+	if err := parent.Close(); err != nil {
+		t.Fatalf("Close: %v", err)
 	}
 
 	done := make(chan struct{})
@@ -107,9 +73,9 @@ func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
 }
 
 func TestTunFile_Close_Idempotent(t *testing.T) {
-	tf, err := newTunFd(newReadPipe(t))
+	tf, err := newOffload(newReadPipe(t), 1)
 	if err != nil {
-		t.Fatalf("newTunFd: %v", err)
+		t.Fatalf("newOffload: %v", err)
 	}
 	if err := tf.Close(); err != nil {
 		t.Fatalf("first Close: %v", err)
diff --git a/overlay/tio/tun_linux_offload_test.go b/overlay/tio/tun_linux_offload_test.go
index ff080b7c..20ca9cd9 100644
--- 
a/overlay/tio/tun_linux_offload_test.go +++ b/overlay/tio/tun_linux_offload_test.go @@ -309,7 +309,7 @@ func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) { } t.Cleanup(func() { _ = unix.Close(fd) }) - tf := &tunFile{fd: fd} + tf := &Offload{fd: fd} tf.writeIovs[0].Base = &validVnetHdr[0] tf.writeIovs[0].SetLen(virtioNetHdrLen) diff --git a/overlay/tio/tun_linux_test.go b/overlay/tio/tun_linux_test.go deleted file mode 100644 index 607ed828..00000000 --- a/overlay/tio/tun_linux_test.go +++ /dev/null @@ -1,38 +0,0 @@ -//go:build !e2e_testing -// +build !e2e_testing - -package tio - -import ( - "testing" - - "github.com/slackhq/nebula/overlay" -) - -var runAdvMSSTests = []struct { - name string - tun *overlay.tun - r overlay.Route - expected int -}{ - // Standard case, default MTU is the device max MTU - {"default", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{}, 0}, - {"default-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1440}, 0}, - {"default-low", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1200}, 1160}, - - // Case where we have a route MTU set higher than the default - {"route", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{}, 1400}, - {"route-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 1440}, 1400}, - {"route-high", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 8941}, 0}, -} - -func TestTunAdvMSS(t *testing.T) { - for _, tt := range runAdvMSSTests { - t.Run(tt.name, func(t *testing.T) { - o := tt.tun.advMSS(tt.r) - if o != tt.expected { - t.Errorf("got %d, want %d", o, tt.expected) - } - }) - } -} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index a81a958d..44e5e9f6 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -176,7 +176,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetwo var container tio.Container var err error if vnetHdr { - container, err = tio.NewGSOContainer() + 
container, err = tio.NewOffloadContainer() } else { container, err = tio.NewPollContainer() } diff --git a/overlay/tun_linux_test.go b/overlay/tun_linux_test.go new file mode 100644 index 00000000..1003a165 --- /dev/null +++ b/overlay/tun_linux_test.go @@ -0,0 +1,36 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "testing" +) + +var runAdvMSSTests = []struct { + name string + tun *tun + r Route + expected int +}{ + // Standard case, default MTU is the device max MTU + {"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, + {"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, + {"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, + + // Case where we have a route MTU set higher than the default + {"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, + {"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, + {"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, +} + +func TestTunAdvMSS(t *testing.T) { + for _, tt := range runAdvMSSTests { + t.Run(tt.name, func(t *testing.T) { + o := tt.tun.advMSS(tt.r) + if o != tt.expected { + t.Errorf("got %d, want %d", o, tt.expected) + } + }) + } +} diff --git a/overlay/user.go b/overlay/user.go index d8b53cf0..f15aafd8 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -24,6 +24,7 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { outboundWriter: ow, inboundReader: ir, inboundWriter: iw, + numReaders: 1, }, nil }