new tun interface

This commit is contained in:
JackDoan
2026-04-17 10:25:05 -05:00
parent 398d67e2da
commit 2bdd284993
21 changed files with 875 additions and 463 deletions

View File

@@ -4,15 +4,13 @@
package e2e package e2e
import ( import (
"io" "log/slog"
"net/netip" "net/netip"
"os" "os"
"strings" "strings"
"testing" "testing"
"time" "time"
"log/slog"
"dario.cat/mergo" "dario.cat/mergo"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -382,7 +380,7 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
func NewTestLogger() *slog.Logger { func NewTestLogger() *slog.Logger {
v := os.Getenv("TEST_LOGS") v := os.Getenv("TEST_LOGS")
if v == "" { if v == "" {
return slog.New(slog.NewTextHandler(io.Discard, nil)) return slog.New(slog.DiscardHandler)
} }
level := slog.LevelInfo level := slog.LevelInfo

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"sync" "sync"
@@ -13,6 +12,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -88,7 +88,7 @@ type Interface struct {
ctx context.Context ctx context.Context
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []tio.Queue
wg sync.WaitGroup wg sync.WaitGroup
// fatalErr holds the first unexpected reader error that caused shutdown. // fatalErr holds the first unexpected reader error that caused shutdown.
@@ -187,7 +187,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines, routines: c.routines,
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]tio.Queue, c.routines),
myVpnNetworks: cs.myVpnNetworks, myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable, myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs, myVpnAddrs: cs.myVpnAddrs,
@@ -245,16 +245,14 @@ func (f *Interface) activate() error {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
var reader io.ReadWriteCloser = f.inside
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() if err = f.inside.NewMultiQueueReader(); err != nil {
if err != nil {
return err return err
} }
} }
f.readers[i] = reader
} }
f.readers = f.inside.Readers()
f.wg.Add(1) // for us to wait on Close() to return f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
@@ -328,8 +326,7 @@ func (f *Interface) listenOut(i int) {
f.l.Debug("underlay reader is done", "reader", i) f.l.Debug("underlay reader is done", "reader", i)
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader tio.Queue, i int) {
packet := make([]byte, mtu)
out := make([]byte, mtu) out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
@@ -337,7 +334,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
for { for {
n, err := reader.Read(packet) pkts, err := reader.Read()
if err != nil { if err != nil {
if !f.closed.Load() { if !f.closed.Load() {
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
@@ -345,8 +342,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
} }
break break
} }
for _, pkt := range pkts {
f.consumeInsidePacket(pkt.Bytes, fwPacket, nb, out, i, conntrackCache.Get())
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
} }
f.l.Debug("overlay reader is done", "reader", i) f.l.Debug("overlay reader is done", "reader", i)

View File

@@ -4,15 +4,21 @@ import (
"io" "io"
"net/netip" "net/netip"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
type Device interface { type Device interface {
io.ReadWriteCloser io.Closer
Activate() error Activate() error
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() error
Readers() []tio.Queue
} }

View File

@@ -4,9 +4,9 @@ package overlaytest
import ( import (
"errors" "errors"
"io"
"net/netip" "net/netip"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -31,8 +31,8 @@ func (NoopTun) Name() string {
return "noop" return "noop"
} }
func (NoopTun) Read([]byte) (int, error) { func (NoopTun) Read() ([]tio.Packet, error) {
return 0, nil return nil, nil
} }
func (NoopTun) Write([]byte) (int, error) { func (NoopTun) Write([]byte) (int, error) {
@@ -43,8 +43,12 @@ func (NoopTun) SupportsMultiqueue() bool {
return false return false
} }
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) NewMultiQueueReader() error {
return nil, errors.New("unsupported") return errors.New("unsupported")
}
func (NoopTun) Readers() []tio.Queue {
return []tio.Queue{NoopTun{}}
} }
func (NoopTun) Close() error { func (NoopTun) Close() error {

View File

@@ -0,0 +1,80 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
type pollQueueSet struct {
pq []*Poll
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
}
func NewPollQueueSet() (QueueSet, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &pollQueueSet{
pq: []*Poll{},
pqi: []Queue{},
shutdownFd: shutdownFd,
}
return out, nil
}
func (c *pollQueueSet) Queues() []Queue {
return c.pqi
}
func (c *pollQueueSet) Add(fd int) error {
x, err := newPoll(fd, c.shutdownFd)
if err != nil {
return err
}
c.pq = append(c.pq, x)
c.pqi = append(c.pqi, x)
return nil
}
func (c *pollQueueSet) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(c.shutdownFd, buf[:])
return err
}
func (c *pollQueueSet) Close() error {
if c.shutdownFd < 0 {
return nil
}
errs := []error{}
if err := c.wakeForShutdown(); err != nil {
errs = append(errs, err)
}
for _, x := range c.pq {
if err := x.Close(); err != nil {
errs = append(errs, err)
}
}
// All Polls reference shutdownFd in their pollfd arrays, so close it
// only after every Poll.Close has returned.
if err := unix.Close(c.shutdownFd); err != nil {
errs = append(errs, err)
}
c.shutdownFd = -1
return errors.Join(errs...)
}

123
overlay/tio/tio.go Normal file
View File

@@ -0,0 +1,123 @@
package tio
import (
"io"
)
// QueueSet holds one or many Queue objects and helps close them in an orderly way.
type QueueSet interface {
io.Closer
Queues() []Queue
// Add takes a tun fd, adds it to the set, and prepares it for use as a Queue.
Add(fd int) error
}
// Capabilities advertises which kernel offload features a Queue
// successfully negotiated. Callers consult this to decide which coalescers
// to wire onto the write path — a Queue without TSO can't usefully accept a
// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer.
type Capabilities struct {
// TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed
// to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe.
TSO bool
// USO means the kernel additionally agreed to TUN_F_USO4|USO6, so
// WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2.
USO bool
}
// Queue is a readable/writable Poll queue. One Queue is driven by a single
// read goroutine plus a single writer (see Write below).
type Queue interface {
io.Closer
// Read returns one or more packets. The returned Packet.Bytes slices
// are borrowed from the Queue's internal buffer and are only valid
// until the next Read or Close on this Queue - callers must encrypt
// or copy each slice before the next call. A Packet may carry a
// GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is
// true the caller must segment Bytes before treating it as a single
// IP datagram. Not safe for concurrent Reads.
Read() ([]Packet, error)
// Write emits a single packet on the plaintext (outside→inside)
// delivery path. Not safe for concurrent Writes.
Write(p []byte) (int, error)
}
// Packet is the unit Queue.Read returns. Bytes points into the queue's
// internal buffer and is only valid until the next Read or Close on the
// queue that produced it. GSO is the zero value for an already-segmented
// IP datagram; when non-zero it describes a kernel-supplied TSO/USO
// superpacket the caller must segment before consuming.
type Packet struct {
Bytes []byte
GSO GSOInfo
}
// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes.
// The zero value means "not a superpacket" — Bytes is one regular IP
// datagram and no segmentation is required.
type GSOInfo struct {
// Size is the GSO segment size: max payload bytes per segment
// (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means
// not a superpacket.
Size uint16
// HdrLen is the total L3+L4 header length within Bytes (already
// corrected via correctHdrLen, so safe to slice on).
HdrLen uint16
// CsumStart is the L4 header offset inside Bytes (== L3 header
// length).
CsumStart uint16
// Proto picks the L4 protocol (TCP or UDP) so the segmenter knows
// which checksum/header layout to apply.
Proto GSOProto
}
// IsSuperpacket reports whether g describes a multi-segment GSO/USO
// superpacket that needs segmentation before its bytes can be encrypted
// and sent on the wire.
func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 }
// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes,
// safe to retain past the next Read or Close on the originating Queue.
// GSO metadata is copied verbatim. Use this only when a caller genuinely
// needs to outlive the borrowed-slice contract — the hot path reads should
// continue to consume the borrow synchronously to avoid the allocation.
func (p Packet) Clone() Packet {
if p.Bytes == nil {
return p
}
cp := make([]byte, len(p.Bytes))
copy(cp, p.Bytes)
return Packet{Bytes: cp, GSO: p.GSO}
}
// CapsProvider is an optional interface implemented by Queues that
// successfully negotiated kernel offload features at open time. Callers
// pick a write-path coalescer based on the result. Queues that don't
// implement it are treated as having no offload capability — callers must
// fall back to plain per-packet writes.
type CapsProvider interface {
Capabilities() Capabilities
}
// QueueCapabilities returns q's negotiated offload capabilities, or the
// zero value when q does not advertise any.
func QueueCapabilities(q Queue) Capabilities {
if cp, ok := q.(CapsProvider); ok {
return cp.Capabilities()
}
return Capabilities{}
}
// GSOProto selects the L4 protocol for a GSO superpacket. Determines which
// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset
// inside the transport header virtio NEEDS_CSUM expects.
type GSOProto uint8
const (
GSOProtoNone GSOProto = iota
GSOProtoTCP
GSOProtoUDP
)

View File

@@ -0,0 +1,167 @@
package tio
import (
"fmt"
"os"
"sync"
"sync/atomic"
"golang.org/x/sys/unix"
)
type Poll struct {
fd int
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
writeLock sync.Mutex
closed atomic.Bool
readBuf []byte
batchRet [1]Packet
}
func newPoll(fd int, shutdownFd int) (*Poll, error) {
if err := unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err)
}
out := &Poll{
fd: fd,
readBuf: make([]byte, 65535),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writeLock: sync.Mutex{},
}
return out, nil
}
// blockOnRead waits until the Poll fd is readable or shutdown has been signaled.
// Returns os.ErrClosed if Close was called.
func (t *Poll) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(t.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
tunEvents := t.readPoll[0].Revents
shutdownEvents := t.readPoll[1].Revents
t.readPoll[0].Revents = 0
t.readPoll[1].Revents = 0
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
}
if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (t *Poll) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(t.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
t.writeLock.Lock()
tunEvents := t.writePoll[0].Revents
shutdownEvents := t.writePoll[1].Revents
t.writePoll[0].Revents = 0
t.writePoll[1].Revents = 0
t.writeLock.Unlock()
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
}
if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (t *Poll) Read() ([]Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}
func (t *Poll) readOne(to []byte) (int, error) {
for {
n, errno := unix.Read(t.fd, to)
if errno == nil {
return n, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnRead(); err != nil {
return 0, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
}
func (t *Poll) Write(from []byte) (int, error) {
for {
n, errno := unix.Write(t.fd, from)
if errno == nil {
return n, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnWrite(); err != nil {
return 0, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
}
func (t *Poll) Close() error {
if t.closed.Swap(true) {
return nil
}
//shutdownFd is owned by the container, so we should not close it
var err error
if t.fd >= 0 {
err = unix.Close(t.fd)
t.fd = -1
}
return err
}
func (t *Poll) Capabilities() Capabilities {
return Capabilities{TSO: false, USO: false}
}

View File

@@ -0,0 +1,103 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"errors"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it into a QueueSet).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
t.Fatalf("pipe2: %v", err)
}
t.Cleanup(func() { _ = unix.Close(fds[1]) })
return fds[0]
}
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
pipe1 := newReadPipe(t)
pipe2 := newReadPipe(t)
parent, err := NewPollQueueSet()
require.NoError(t, err)
require.NoError(t, parent.Add(pipe1))
require.NoError(t, parent.Add(pipe2))
t.Cleanup(func() {
_ = unix.Close(pipe1)
_ = unix.Close(pipe2)
})
readers := parent.Queues()
errs := make([]error, len(readers))
var wg sync.WaitGroup
for i, r := range readers {
wg.Add(1)
go func(i int, r Queue) {
defer wg.Done()
_, errs[i] = r.Read()
}(i, r)
}
time.Sleep(50 * time.Millisecond)
if err := parent.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("readers did not wake")
}
for i, err := range errs {
if !errors.Is(err, os.ErrClosed) {
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
}
}
}
func TestPoll_Close_Idempotent(t *testing.T) {
tf, err := newPoll(newReadPipe(t), 1)
require.NoError(t, err)
if err := tf.Close(); err != nil {
t.Fatalf("first Close: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("second Close should be a no-op, got %v", err)
}
}
func TestPollQueueSet_Close_ClosesEventfd(t *testing.T) {
qs, err := NewPollQueueSet()
require.NoError(t, err)
require.NoError(t, qs.Add(newReadPipe(t)))
fd := qs.(*pollQueueSet).shutdownFd
require.NoError(t, qs.Close())
// Closing the eventfd again should fail with EBADF, proving Close
// actually released it.
if err := unix.Close(fd); err == nil {
t.Fatalf("eventfd %d still open after QueueSet.Close", fd)
}
// Second Close must be a no-op (and must not double-close the eventfd
// in case the kernel handed it out to another caller in the meantime).
if err := qs.Close(); err != nil {
t.Fatalf("second Close: %v", err)
}
}

View File

@@ -13,17 +13,38 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
io.ReadWriteCloser rwc io.ReadWriteCloser
fd int fd int
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *slog.Logger l *slog.Logger
readBuf []byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.rwc.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}
func (t *tun) Write(p []byte) (int, error) {
return t.rwc.Write(p)
}
func (t *tun) Close() error {
return t.rwc.Close()
} }
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
@@ -32,10 +53,11 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t := &tun{ t := &tun{
ReadWriteCloser: file, rwc: file,
fd: deviceFd, fd: deviceFd,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
l: l, l: l,
readBuf: make([]byte, defaultBatchBufSize),
} }
err := t.reload(c, true) err := t.reload(c, true)
@@ -62,7 +84,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r return r
} }
func (t tun) Activate() error { func (t *tun) Activate() error {
return nil return nil
} }
@@ -99,6 +121,10 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android") return fmt.Errorf("TODO: multiqueue not implemented for android")
}
func (t *tun) Readers() []tio.Queue {
return []tio.Queue{t}
} }

View File

@@ -16,6 +16,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
@@ -23,7 +24,7 @@ import (
) )
type tun struct { type tun struct {
io.ReadWriteCloser rwc io.ReadWriteCloser
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
DefaultMTU int DefaultMTU int
@@ -34,6 +35,9 @@ type tun struct {
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte out []byte
readBuf []byte
batchRet [1]tio.Packet
} }
type ifReq struct { type ifReq struct {
@@ -124,11 +128,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
} }
t := &tun{ t := &tun{
ReadWriteCloser: os.NewFile(uintptr(fd), ""), rwc: os.NewFile(uintptr(fd), ""),
Device: name, Device: name,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
readBuf: make([]byte, defaultBatchBufSize),
} }
err = t.reload(c, true) err = t.reload(c, true)
@@ -158,8 +163,8 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, e
} }
func (t *tun) Close() error { func (t *tun) Close() error {
if t.ReadWriteCloser != nil { if t.rwc != nil {
return t.ReadWriteCloser.Close() return t.rwc.Close()
} }
return nil return nil
} }
@@ -502,15 +507,24 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) readOne(to []byte) (int, error) {
buf := make([]byte, len(to)+4) buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf) n, err := t.rwc.Read(buf)
copy(to, buf[4:]) copy(to, buf[4:])
return n - 4, err return n - 4, err
} }
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}
// Write is only valid for single threaded use // Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) { func (t *tun) Write(from []byte) (int, error) {
buf := t.out buf := t.out
@@ -536,7 +550,7 @@ func (t *tun) Write(from []byte) (int, error) {
copy(buf[4:], from) copy(buf[4:], from)
n, err := t.ReadWriteCloser.Write(buf) n, err := t.rwc.Write(buf)
return n - 4, err return n - 4, err
} }
@@ -552,6 +566,10 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func (t *tun) Readers() []tio.Queue {
return []tio.Queue{t}
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -21,6 +22,42 @@ type disabledTun struct {
tx metrics.Counter tx metrics.Counter
rx metrics.Counter rx metrics.Counter
l *slog.Logger l *slog.Logger
numReaders int
}
// disabledQueue is one tio.Queue view onto a shared disabledTun. Each queue
// owns a private batchRet so concurrent Read calls from different reader
// goroutines do not race on the returned slice.
type disabledQueue struct {
parent *disabledTun
batchRet [1]tio.Packet
}
func (q *disabledQueue) Read() ([]tio.Packet, error) {
r, ok := <-q.parent.read
if !ok {
return nil, io.EOF
}
q.parent.tx.Inc(1)
if q.parent.l.Enabled(context.Background(), slog.LevelDebug) {
q.parent.l.Debug("Write payload", "raw", prettyPacket(r))
}
q.batchRet[0] = tio.Packet{Bytes: r}
return q.batchRet[:], nil
}
// Write on a queue forwards to the underlying disabledTun. All queues share
// one ICMP-handling/log path so this is a thin pass-through.
func (q *disabledQueue) Write(b []byte) (int, error) {
return q.parent.Write(b)
}
// Close on a queue is a no-op. The shared channel and metrics are owned by
// the disabledTun; Close on the device tears them down once for everybody.
func (q *disabledQueue) Close() error {
return nil
} }
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
@@ -28,6 +65,7 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen), read: make(chan []byte, queueLen),
l: l, l: l,
numReaders: 1,
} }
if metricsEnabled { if metricsEnabled {
@@ -57,24 +95,6 @@ func (*disabledTun) Name() string {
return "disabled" return "disabled"
} }
func (t *disabledTun) Read(b []byte) (int, error) {
r, ok := <-t.read
if !ok {
return 0, io.EOF
}
if len(r) > len(b) {
return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
}
t.tx.Inc(1)
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Write payload", "raw", prettyPacket(r))
}
return copy(b, r), nil
}
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
out := make([]byte, len(b)) out := make([]byte, len(b))
out = iputil.CreateICMPEchoResponse(b, out) out = iputil.CreateICMPEchoResponse(b, out)
@@ -110,8 +130,17 @@ func (t *disabledTun) SupportsMultiqueue() bool {
return true return true
} }
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *disabledTun) NewMultiQueueReader() error {
return t, nil t.numReaders++
return nil
}
func (t *disabledTun) Readers() []tio.Queue {
out := make([]tio.Queue, t.numReaders)
for i := range t.numReaders {
out[i] = &disabledQueue{parent: t}
}
return out
} }
func (t *disabledTun) Close() error { func (t *disabledTun) Close() error {

View File

@@ -1,120 +0,0 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay
import (
"errors"
"os"
"sync"
"testing"
"time"
"golang.org/x/sys/unix"
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
t.Fatalf("pipe2: %v", err)
}
t.Cleanup(func() { _ = unix.Close(fds[1]) })
return fds[0]
}
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
t.Cleanup(func() { _ = tf.Close() })
done := make(chan error, 1)
go func() {
_, err := tf.Read(make([]byte, 64))
done <- err
}()
// Verify Read is actually blocked in poll.
select {
case err := <-done:
t.Fatalf("Read returned before shutdown signal: %v", err)
case <-time.After(50 * time.Millisecond):
}
if err := tf.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
select {
case err := <-done:
if !errors.Is(err, os.ErrClosed) {
t.Fatalf("expected os.ErrClosed, got %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Read did not wake on shutdown")
}
}
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
parent, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
friend, err := parent.newFriend(newReadPipe(t))
if err != nil {
_ = parent.Close()
t.Fatalf("newFriend: %v", err)
}
t.Cleanup(func() {
_ = friend.Close()
_ = parent.Close()
})
readers := []*tunFile{parent, friend}
errs := make([]error, len(readers))
var wg sync.WaitGroup
for i, r := range readers {
wg.Add(1)
go func(i int, r *tunFile) {
defer wg.Done()
_, errs[i] = r.Read(make([]byte, 64))
}(i, r)
}
time.Sleep(50 * time.Millisecond)
if err := parent.wakeForShutdown(); err != nil {
t.Fatalf("wakeForShutdown: %v", err)
}
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("readers did not wake")
}
for i, err := range errs {
if !errors.Is(err, os.ErrClosed) {
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
}
}
}
func TestTunFile_Close_Idempotent(t *testing.T) {
tf, err := newTunFd(newReadPipe(t))
if err != nil {
t.Fatalf("newTunFd: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("first Close: %v", err)
}
if err := tf.Close(); err != nil {
t.Fatalf("second Close should be a no-op, got %v", err)
}
}

View File

@@ -7,7 +7,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"net/netip" "net/netip"
@@ -20,7 +19,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
@@ -103,6 +102,9 @@ type tun struct {
readPoll [2]unix.PollFd readPoll [2]unix.PollFd
writePoll [2]unix.PollFd writePoll [2]unix.PollFd
closed atomic.Bool closed atomic.Bool
readBuf []byte
batchRet [1]tio.Packet
} }
// blockOnRead waits until the tun fd is readable or shutdown has been signaled. // blockOnRead waits until the tun fd is readable or shutdown has been signaled.
@@ -157,7 +159,16 @@ func (t *tun) blockOnWrite() error {
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}
func (t *tun) readOne(to []byte) (int, error) {
// first 4 bytes is protocol family, in network byte order // first 4 bytes is protocol family, in network byte order
var head [4]byte var head [4]byte
iovecs := [2]syscall.Iovec{ iovecs := [2]syscall.Iovec{
@@ -375,6 +386,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
fd: fd, fd: fd,
readBuf: make([]byte, defaultBatchBufSize),
shutdownR: shutdownR, shutdownR: shutdownR,
shutdownW: shutdownW, shutdownW: shutdownW,
readPoll: [2]unix.PollFd{ readPoll: [2]unix.PollFd{
@@ -565,8 +577,8 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") return fmt.Errorf("TODO: multiqueue not implemented for freebsd")
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
@@ -593,6 +605,10 @@ func (t *tun) addRoutes(logErrors bool) error {
return nil return nil
} }
func (t *tun) Readers() []tio.Queue {
return []tio.Queue{t}
}
func (t *tun) removeRoutes(routes []Route) error { func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {

View File

@@ -16,16 +16,37 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
io.ReadWriteCloser rwc io.ReadWriteCloser
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *slog.Logger l *slog.Logger
readBuf []byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.rwc.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}
func (t *tun) Write(p []byte) (int, error) {
return t.rwc.Write(p)
}
func (t *tun) Close() error {
return t.rwc.Close()
} }
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
@@ -36,8 +57,9 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
file := os.NewFile(uintptr(deviceFd), "/dev/tun") file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{ t := &tun{
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
ReadWriteCloser: &tunReadCloser{f: file}, rwc: &tunReadCloser{f: file},
l: l, l: l,
readBuf: make([]byte, defaultBatchBufSize),
} }
err := t.reload(c, true) err := t.reload(c, true)
@@ -155,6 +177,10 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") return fmt.Errorf("TODO: multiqueue not implemented for ios")
}
func (t *tun) Readers() []tio.Queue {
return []tio.Queue{t}
} }

View File

@@ -4,9 +4,7 @@
package overlay package overlay
import ( import (
"encoding/binary"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"net" "net"
"net/netip" "net/netip"
@@ -19,180 +17,15 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct {
fd int
shutdownFd int
lastOne bool
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed bool
}
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
return &tunFile{
fd: fd,
shutdownFd: r.shutdownFd,
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
}, nil
}
func newTunFd(fd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &tunFile{
fd: fd,
shutdownFd: shutdownFd,
lastOne: true,
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
}
return out, nil
}
func (r *tunFile) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.readPoll[0].Revents
shutdownEvents := r.readPoll[1].Revents
r.readPoll[0].Revents = 0
r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) Read(buf []byte) (int, error) {
for {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnRead(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else if err == unix.EBADF {
return 0, os.ErrClosed
} else {
return 0, err
}
}
}
func (r *tunFile) Write(buf []byte) (int, error) {
for {
if n, err := unix.Write(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnWrite(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else if err == unix.EBADF {
return 0, os.ErrClosed
} else {
return 0, err
}
}
}
func (r *tunFile) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
return err
}
func (r *tunFile) Close() error {
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
return nil
}
r.closed = true
if r.lastOne {
_ = unix.Close(r.shutdownFd)
}
return unix.Close(r.fd)
}
type tun struct { type tun struct {
*tunFile readers tio.QueueSet
readers []*tunFile
closeLock sync.Mutex closeLock sync.Mutex
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
@@ -239,7 +72,9 @@ type ifreqQLEN struct {
} }
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) // We don't know what flags the caller opened this fd with and can't turn
// on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds.
t, err := newTunGeneric(c, l, deviceFd, false, false, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -249,46 +84,65 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
return t, nil return t, nil
} }
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { // openTunDev opens /dev/net/tun, creating the device node first if it's
// missing (docker containers occasionally omit it).
func openTunDev() (int, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err == nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker) return fd, nil
if os.IsNotExist(err) {
err = os.MkdirAll("/dev/net", 0755)
if err != nil {
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
} }
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) if !os.IsNotExist(err) {
if err != nil { return -1, err
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) }
if err = os.MkdirAll("/dev/net", 0755); err != nil {
return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
}
if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err)
} }
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
}
} else {
return nil, err
}
} }
return fd, nil
}
// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen
// device name on success.
func tunSetIff(fd int, name string, flags uint16) (string, error) {
var req ifReq var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) req.Flags = flags
copy(req.Name[:], name)
if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return "", err
}
return strings.Trim(string(req.Name[:]), "\x00"), nil
}
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue { if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE baseFlags |= unix.IFF_MULTI_QUEUE
} }
nameStr := c.GetString("tun.dev", "") nameStr := c.GetString("tun.dev", "")
copy(req.Name[:], nameStr)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
_ = unix.Close(fd)
return nil, &NameError{
Name: nameStr,
Underlying: err,
}
}
name := strings.Trim(string(req.Name[:]), "\x00")
t, err := newTunGeneric(c, l, fd, vpnNetworks) // First try to enable IFF_VNET_HDR via TUNSETIFF and negotiate TUN_F_*
// offloads via TUNSETOFFLOAD so we can receive TSO/USO superpackets.
// We try TSO+USO first, fall back to TSO-only on kernels without USO
// (Linux < 6.2), and finally give up on virtio headers entirely and
// reopen as a plain TUN if neither offload mask is accepted.
fd, err := openTunDev()
if err != nil {
return nil, err
}
name, err := tunSetIff(fd, nameStr, baseFlags)
if err != nil {
_ = unix.Close(fd)
return nil, &NameError{Name: nameStr, Underlying: err}
}
t, err := newTunGeneric(c, l, fd, false, false, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -299,15 +153,21 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue
} }
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *slog.Logger, fd int, vnetHdr, usoEnabled bool, vpnNetworks []netip.Prefix) (*tun, error) {
tfd, err := newTunFd(fd) qs, err := tio.NewPollQueueSet()
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, err return nil, err
} }
err = qs.Add(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
}
t := &tun{ t := &tun{
tunFile: tfd, readers: qs,
readers: []*tunFile{tfd},
closeLock: sync.Mutex{}, closeLock: sync.Mutex{},
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
@@ -410,32 +270,29 @@ func (t *tun) SupportsMultiqueue() bool {
return true return true
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() error {
t.closeLock.Lock() t.closeLock.Lock()
defer t.closeLock.Unlock() defer t.closeLock.Unlock()
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return err
} }
var req ifReq flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
copy(req.Name[:], t.Device) if _, err = tunSetIff(fd, t.Device, flags); err != nil {
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, err return err
} }
out, err := t.tunFile.newFriend(fd) err = t.readers.Add(fd)
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, err return err
} }
t.readers = append(t.readers, out) return nil
return out, nil
} }
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -603,6 +460,15 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
Table: unix.RT_TABLE_MAIN, Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
} }
// Match the metric the kernel uses for its auto-installed connected
// route, so RouteReplace overwrites it in place instead of adding a
// second route at a worse metric. IPv6 connected routes are installed
// at metric 256 (IP6_RT_PRIO_KERN); IPv4 uses 0. Without this, the
// kernel route wins lookups and our MTU / AdvMSS / Features never
// apply on v6.
if cidr.Addr().Is6() {
nr.Priority = 256
}
err := netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
@@ -869,6 +735,10 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
t.routeTree.Store(newTree) t.routeTree.Store(newTree)
} }
func (t *tun) Readers() []tio.Queue {
return t.readers.Queues()
}
func (t *tun) Close() error { func (t *tun) Close() error {
t.closeLock.Lock() t.closeLock.Lock()
defer t.closeLock.Unlock() defer t.closeLock.Unlock()
@@ -878,32 +748,10 @@ func (t *tun) Close() error {
t.routeChan = nil t.routeChan = nil
} }
// Signal all readers blocked in poll to wake up and exit
_ = t.tunFile.wakeForShutdown()
if t.ioctlFd > 0 { if t.ioctlFd > 0 {
_ = unix.Close(int(t.ioctlFd)) _ = unix.Close(int(t.ioctlFd))
t.ioctlFd = 0 t.ioctlFd = 0
} }
for i := range t.readers { return t.readers.Close()
if i == 0 {
continue //we want to close the zeroth reader last
}
err := t.readers[i].Close()
if err != nil {
t.l.Error("error closing tun reader", "reader", i, "error", err)
} else {
t.l.Info("closed tun reader", "reader", i)
}
}
//this is t.readers[0] too
err := t.tunFile.Close()
if err != nil {
t.l.Error("error closing tun reader", "reader", 0, "error", err)
} else {
t.l.Info("closed tun reader", "reader", 0)
}
return err
} }

View File

@@ -3,7 +3,9 @@
package overlay package overlay
import "testing" import (
"testing"
)
var runAdvMSSTests = []struct { var runAdvMSSTests = []struct {
name string name string

View File

@@ -6,7 +6,6 @@ package overlay
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"os" "os"
@@ -17,6 +16,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
@@ -66,6 +66,22 @@ type tun struct {
l *slog.Logger l *slog.Logger
f *os.File f *os.File
fd int fd int
readBuf []byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}
func (t *tun) Readers() []tio.Queue {
return []tio.Queue{t}
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -102,6 +118,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
readBuf: make([]byte, defaultBatchBufSize),
} }
err = t.reload(c, true) err = t.reload(c, true)
@@ -141,7 +158,7 @@ func (t *tun) Close() error {
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) readOne(to []byte) (int, error) {
rc, err := t.f.SyscallConn() rc, err := t.f.SyscallConn()
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err) return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
@@ -394,8 +411,8 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") return fmt.Errorf("TODO: multiqueue not implemented for netbsd")
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {

View File

@@ -6,7 +6,6 @@ package overlay
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"os" "os"
@@ -17,6 +16,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
@@ -59,6 +59,18 @@ type tun struct {
fd int fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte out []byte
readBuf []byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -95,6 +107,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
readBuf: make([]byte, defaultBatchBufSize),
} }
err = t.reload(c, true) err = t.reload(c, true)
@@ -124,7 +137,7 @@ func (t *tun) Close() error {
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) readOne(to []byte) (int, error) {
buf := make([]byte, len(to)+4) buf := make([]byte, len(to)+4)
n, err := t.f.Read(buf) n, err := t.f.Read(buf)
@@ -314,8 +327,8 @@ func (t *tun) SupportsMultiqueue() bool {
return false return false
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") return fmt.Errorf("TODO: multiqueue not implemented for openbsd")
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
@@ -366,6 +379,10 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func (t *tun) Readers() []tio.Queue {
return []tio.Queue{t}
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error { func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil { if err != nil {

View File

@@ -14,6 +14,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -28,6 +29,8 @@ type TestTun struct {
closed atomic.Bool closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
batchRet [1]tio.Packet
} }
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
@@ -48,6 +51,9 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*T
l: l, l: l,
rxPackets: make(chan []byte, 10), rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10),
batchRet: [1]tio.Packet{
tio.Packet{Bytes: make([]byte, udp.MTU)},
},
}, nil }, nil
} }
@@ -162,7 +168,17 @@ func (t *TestTun) Close() error {
return nil return nil
} }
func (t *TestTun) Read(b []byte) (int, error) { func (t *TestTun) Read() ([]tio.Packet, error) {
t.batchRet[0].Bytes = t.batchRet[0].Bytes[:udp.MTU]
n, err := t.read(t.batchRet[0].Bytes)
if err != nil {
return nil, err
}
t.batchRet[0].Bytes = t.batchRet[0].Bytes[:n]
return t.batchRet[:], nil
}
func (t *TestTun) read(b []byte) (int, error) {
p, ok := <-t.rxPackets p, ok := <-t.rxPackets
if !ok { if !ok {
return 0, os.ErrClosed return 0, os.ErrClosed
@@ -177,10 +193,14 @@ func (t *TestTun) Read(b []byte) (int, error) {
return n, nil return n, nil
} }
func (t *TestTun) Readers() []tio.Queue {
return []tio.Queue{t}
}
func (t *TestTun) SupportsMultiqueue() bool { func (t *TestTun) SupportsMultiqueue() bool {
return false return false
} }
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *TestTun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented") return fmt.Errorf("TODO: multiqueue not implemented")
} }

View File

@@ -6,7 +6,6 @@ package overlay
import ( import (
"crypto" "crypto"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"os" "os"
@@ -18,6 +17,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun" "github.com/slackhq/nebula/wintun"
@@ -45,6 +45,18 @@ type winTun struct {
l *slog.Logger l *slog.Logger
tun *wintun.NativeTun tun *wintun.NativeTun
readBuf []byte
batchRet [1]tio.Packet
}
func (t *winTun) Read() ([]tio.Packet, error) {
n, err := t.tun.Read(t.readBuf, 0)
if err != nil {
return nil, err
}
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
} }
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) { func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
@@ -69,6 +81,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
} }
t := &winTun{ t := &winTun{
readBuf: make([]byte, defaultBatchBufSize),
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -255,10 +268,6 @@ func (t *winTun) Name() string {
return t.Device return t.Device
} }
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) { func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0) return t.tun.Write(b, 0)
} }
@@ -267,8 +276,12 @@ func (t *winTun) SupportsMultiqueue() bool {
return false return false
} }
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *winTun) NewMultiQueueReader() error {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") return fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Readers() []tio.Queue {
return []tio.Queue{t}
} }
func (t *winTun) Close() error { func (t *winTun) Close() error {

View File

@@ -1,11 +1,13 @@
package overlay package overlay
import ( import (
"errors"
"io" "io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -28,12 +30,28 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
type UserDevice struct { type UserDevice struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
numReaders int
outboundReader *io.PipeReader outboundReader *io.PipeReader
outboundWriter *io.PipeWriter outboundWriter *io.PipeWriter
inboundReader *io.PipeReader inboundReader *io.PipeReader
inboundWriter *io.PipeWriter inboundWriter *io.PipeWriter
readBuf []byte
batchRet [1]tio.Packet
}
func (d *UserDevice) Read() ([]tio.Packet, error) {
if d.readBuf == nil {
d.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := d.outboundReader.Read(d.readBuf)
if err != nil {
return nil, err
}
d.batchRet[0] = tio.Packet{Bytes: d.readBuf[:n]}
return d.batchRet[:], nil
} }
func (d *UserDevice) Activate() error { func (d *UserDevice) Activate() error {
@@ -47,23 +65,25 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
} }
func (d *UserDevice) SupportsMultiqueue() bool { func (d *UserDevice) SupportsMultiqueue() bool {
return true return false
} }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() error {
return d, nil return errors.New("not implemented")
}
func (d *UserDevice) Readers() []tio.Queue {
return []tio.Queue{d}
} }
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
return d.inboundReader, d.outboundWriter return d.inboundReader, d.outboundWriter
} }
func (d *UserDevice) Read(p []byte) (n int, err error) {
return d.outboundReader.Read(p)
}
func (d *UserDevice) Write(p []byte) (n int, err error) { func (d *UserDevice) Write(p []byte) (n int, err error) {
return d.inboundWriter.Write(p) return d.inboundWriter.Write(p)
} }
func (d *UserDevice) Close() error { func (d *UserDevice) Close() error {
d.inboundWriter.Close() d.inboundWriter.Close()
d.outboundWriter.Close() d.outboundWriter.Close()