Compare commits

..

7 Commits

Author SHA1 Message Date
Ryan Huber
a4b7f624da sure 2025-11-03 17:23:57 +00:00
Ryan Huber
1c069a8e42 reuse control on gso 2025-11-03 11:14:52 +00:00
Ryan Huber
0d8bd11818 reuse GRO slices 2025-11-03 11:06:07 +00:00
Ryan Huber
5128e2653e reuse packet buffer 2025-11-03 10:52:09 +00:00
Ryan Huber
c73b2dfbc7 fixed fallback for non io_uring packet send/recv 2025-11-03 10:45:30 +00:00
Ryan Huber
3dea761530 fix compile for 386 2025-11-03 10:12:02 +00:00
Ryan Huber
b394112ad9 gso and gro with uring on send/receive for udp 2025-11-03 09:59:45 +00:00
11 changed files with 3587 additions and 705 deletions

View File

@@ -7,30 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Added
- Experimental Linux UDP offload support: enable `listen.enable_gso` and
`listen.enable_gro` to activate UDP_SEGMENT batching and GRO receive
splitting. Includes automatic capability probing, per-packet fallbacks, and
runtime metrics/logs for visibility.
- Optional Linux TUN `virtio_net_hdr` support: set `tun.enable_vnet_hdr` to
have Nebula negotiate VNET headers and offload flags so future batches can
be delivered to the kernel with metadata instead of per-packet writes.
- Linux UDP send sharding can now be tuned with `listen.send_shards`; defaults
to `GOMAXPROCS` but can be increased to stripe heavy peers across more
goroutines.
### Changed ### Changed
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule - `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is `local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release. deprecated and will be removed in a future release.
- UDP receive path now enqueues into per-worker lock-free rings, restoring the
`listen.decrypt_workers`/`listen.decrypt_queue_depth` tuning knobs while
eliminating the mutex contention from the old shared channel.
- Increased replay protection window to 32k packets so high-throughput links
tolerate larger bursts of reordering without tripping the anti-replay logic.
## [1.9.4] - 2024-09-09 ## [1.9.4] - 2024-09-09

View File

@@ -1,8 +1,10 @@
package cert package cert
import ( import (
"encoding/hex"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"time"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@@ -189,3 +191,71 @@ func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error)
} }
return k.Bytes, r, curve, nil return k.Bytes, r, curve, nil
} }
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
// Convert to old format
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: c.PublicKey(),
IsCA: c.IsCA(),
Curve: c.Curve(),
},
Signature: c.Signature(),
cert: c,
}
// Handle issuer
if c.Issuer() != "" {
issuerBytes, err := hex.DecodeString(c.Issuer())
if err != nil {
return nil, rest, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
nc.Details.Issuer = issuerBytes
}
return nc, rest, nil
}

View File

@@ -13,10 +13,7 @@ import (
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
) )
// ReplayWindow controls the size of the sliding window used to detect replays. const ReplayWindow = 1024
// High-bandwidth links with GRO/GSO can reorder more than a thousand packets in
// flight, so keep this comfortably above the largest expected burst.
const ReplayWindow = 32768
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState

View File

@@ -5,11 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/bits"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -23,12 +21,7 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
const ( const mtu = 9001
mtu = 9001
tunReadBufferSize = mtu * 8
defaultDecryptWorkerFactor = 2
defaultInboundQueueDepth = 1024
)
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
@@ -55,8 +48,6 @@ type InterfaceConfig struct {
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
l *logrus.Logger l *logrus.Logger
DecryptWorkers int
DecryptQueueDepth int
} }
type Interface struct { type Interface struct {
@@ -101,167 +92,7 @@ type Interface struct {
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger l *logrus.Logger
ctx context.Context
udpListenWG sync.WaitGroup
inboundPool sync.Pool
decryptWG sync.WaitGroup
decryptQueues []*inboundRing
decryptWorkers int
decryptStates []decryptWorkerState
decryptCounter atomic.Uint32
}
type inboundPacket struct {
addr netip.AddrPort
payload []byte
release func()
queue int
}
type decryptWorkerState struct {
queue *inboundRing
notify chan struct{}
}
type decryptContext struct {
ctTicker *firewall.ConntrackCacheTicker
plain []byte
head header.H
fwPacket firewall.Packet
light *LightHouseHandler
nebula []byte
}
type inboundCell struct {
seq atomic.Uint64
pkt *inboundPacket
}
type inboundRing struct {
mask uint64
cells []inboundCell
enqueuePos atomic.Uint64
dequeuePos atomic.Uint64
}
func newInboundRing(capacity int) *inboundRing {
if capacity < 2 {
capacity = 2
}
size := nextPowerOfTwo(uint32(capacity))
if size < 2 {
size = 2
}
ring := &inboundRing{
mask: uint64(size - 1),
cells: make([]inboundCell, size),
}
for i := range ring.cells {
ring.cells[i].seq.Store(uint64(i))
}
return ring
}
func nextPowerOfTwo(v uint32) uint32 {
if v == 0 {
return 1
}
return 1 << (32 - bits.LeadingZeros32(v-1))
}
func (r *inboundRing) Enqueue(pkt *inboundPacket) bool {
var cell *inboundCell
pos := r.enqueuePos.Load()
for {
cell = &r.cells[pos&r.mask]
seq := cell.seq.Load()
diff := int64(seq) - int64(pos)
if diff == 0 {
if r.enqueuePos.CompareAndSwap(pos, pos+1) {
break
}
} else if diff < 0 {
return false
} else {
pos = r.enqueuePos.Load()
}
}
cell.pkt = pkt
cell.seq.Store(pos + 1)
return true
}
func (r *inboundRing) Dequeue() (*inboundPacket, bool) {
var cell *inboundCell
pos := r.dequeuePos.Load()
for {
cell = &r.cells[pos&r.mask]
seq := cell.seq.Load()
diff := int64(seq) - int64(pos+1)
if diff == 0 {
if r.dequeuePos.CompareAndSwap(pos, pos+1) {
break
}
} else if diff < 0 {
return nil, false
} else {
pos = r.dequeuePos.Load()
}
}
pkt := cell.pkt
cell.pkt = nil
cell.seq.Store(pos + r.mask + 1)
return pkt, true
}
func (f *Interface) getInboundPacket() *inboundPacket {
if pkt, ok := f.inboundPool.Get().(*inboundPacket); ok && pkt != nil {
return pkt
}
return &inboundPacket{}
}
func (f *Interface) putInboundPacket(pkt *inboundPacket) {
if pkt == nil {
return
}
pkt.addr = netip.AddrPort{}
pkt.payload = nil
pkt.release = nil
pkt.queue = 0
f.inboundPool.Put(pkt)
}
func newDecryptContext(f *Interface) *decryptContext {
return &decryptContext{
ctTicker: firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout),
plain: make([]byte, udp.MTU),
head: header.H{},
fwPacket: firewall.Packet{},
light: f.lightHouse.NewRequestHandler(),
nebula: make([]byte, 12, 12),
}
}
func (f *Interface) processInboundPacket(pkt *inboundPacket, ctx *decryptContext) {
if pkt == nil {
return
}
defer func() {
if pkt.release != nil {
pkt.release()
}
f.putInboundPacket(pkt)
}()
ctx.head = header.H{}
ctx.fwPacket = firewall.Packet{}
var cache firewall.ConntrackCache
if ctx.ctTicker != nil {
cache = ctx.ctTicker.Get(f.l)
}
f.readOutsidePackets(pkt.addr, nil, ctx.plain[:0], pkt.payload, &ctx.head, &ctx.fwPacket, ctx.light, ctx.nebula, pkt.queue, cache)
} }
type EncWriter interface { type EncWriter interface {
@@ -331,35 +162,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
} }
cs := c.pki.getCertState() cs := c.pki.getCertState()
decryptWorkers := c.DecryptWorkers
if decryptWorkers < 0 {
decryptWorkers = 0
}
if decryptWorkers == 0 {
decryptWorkers = c.routines * defaultDecryptWorkerFactor
if decryptWorkers < c.routines {
decryptWorkers = c.routines
}
}
if decryptWorkers < 0 {
decryptWorkers = 0
}
if runtime.GOOS != "linux" {
decryptWorkers = 0
}
queueDepth := c.DecryptQueueDepth
if queueDepth <= 0 {
queueDepth = defaultInboundQueueDepth
}
minDepth := c.routines * 64
if minDepth <= 0 {
minDepth = 64
}
if queueDepth < minDepth {
queueDepth = minDepth
}
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
hostMap: c.HostMap, hostMap: c.HostMap,
@@ -392,10 +194,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
}, },
l: c.l, l: c.l,
ctx: ctx,
inboundPool: sync.Pool{New: func() any { return &inboundPacket{} }},
decryptWorkers: decryptWorkers,
} }
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
@@ -404,19 +203,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
ifce.connectionManager.intf = ifce ifce.connectionManager.intf = ifce
if decryptWorkers > 0 {
ifce.decryptQueues = make([]*inboundRing, decryptWorkers)
ifce.decryptStates = make([]decryptWorkerState, decryptWorkers)
for i := 0; i < decryptWorkers; i++ {
queue := newInboundRing(queueDepth)
ifce.decryptQueues[i] = queue
ifce.decryptStates[i] = decryptWorkerState{
queue: queue,
notify: make(chan struct{}, 1),
}
}
}
return ifce, nil return ifce, nil
} }
@@ -456,68 +242,8 @@ func (f *Interface) activate() {
} }
} }
func (f *Interface) startDecryptWorkers() {
if f.decryptWorkers <= 0 || len(f.decryptQueues) == 0 {
return
}
f.decryptWG.Add(f.decryptWorkers)
for i := 0; i < f.decryptWorkers; i++ {
go f.decryptWorker(i)
}
}
func (f *Interface) decryptWorker(id int) {
defer f.decryptWG.Done()
if id < 0 || id >= len(f.decryptStates) {
return
}
state := f.decryptStates[id]
if state.queue == nil {
return
}
ctx := newDecryptContext(f)
for {
for {
pkt, ok := state.queue.Dequeue()
if !ok {
break
}
f.processInboundPacket(pkt, ctx)
}
if f.closed.Load() || f.ctx.Err() != nil {
for {
pkt, ok := state.queue.Dequeue()
if !ok {
return
}
f.processInboundPacket(pkt, ctx)
}
}
select {
case <-f.ctx.Done():
case <-state.notify:
}
}
}
func (f *Interface) notifyDecryptWorker(idx int) {
if idx < 0 || idx >= len(f.decryptStates) {
return
}
state := f.decryptStates[idx]
if state.notify == nil {
return
}
select {
case state.notify <- struct{}{}:
default:
}
}
func (f *Interface) run() { func (f *Interface) run() {
f.startDecryptWorkers()
// Launch n queues to read packets from udp // Launch n queues to read packets from udp
f.udpListenWG.Add(f.routines)
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
go f.listenOut(i) go f.listenOut(i)
} }
@@ -530,7 +256,6 @@ func (f *Interface) run() {
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
defer f.udpListenWG.Done()
var li udp.Conn var li udp.Conn
if i > 0 { if i > 0 {
@@ -539,78 +264,26 @@ func (f *Interface) listenOut(i int) {
li = f.outside li = f.outside
} }
useWorkers := f.decryptWorkers > 0 && len(f.decryptQueues) > 0 ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
var ( lhh := f.lightHouse.NewRequestHandler()
inlineTicker *firewall.ConntrackCacheTicker plaintext := make([]byte, udp.MTU)
inlineHandler *LightHouseHandler h := &header.H{}
inlinePlain []byte fwPacket := &firewall.Packet{}
inlineHeader header.H nb := make([]byte, 12, 12)
inlinePacket firewall.Packet
inlineNB []byte
inlineCtx *decryptContext
)
if useWorkers {
inlineCtx = newDecryptContext(f)
} else {
inlineTicker = firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
inlineHandler = f.lightHouse.NewRequestHandler()
inlinePlain = make([]byte, udp.MTU)
inlineNB = make([]byte, 12, 12)
}
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) {
if !useWorkers { if release != nil {
if release != nil { defer release()
defer release()
}
select {
case <-f.ctx.Done():
return
default:
}
inlineHeader = header.H{}
inlinePacket = firewall.Packet{}
var cache firewall.ConntrackCache
if inlineTicker != nil {
cache = inlineTicker.Get(f.l)
}
f.readOutsidePackets(fromUdpAddr, nil, inlinePlain[:0], payload, &inlineHeader, &inlinePacket, inlineHandler, inlineNB, i, cache)
return
} }
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
if f.ctx.Err() != nil {
if release != nil {
release()
}
return
}
pkt := f.getInboundPacket()
pkt.addr = fromUdpAddr
pkt.payload = payload
pkt.release = release
pkt.queue = i
queueCount := len(f.decryptQueues)
if queueCount == 0 {
f.processInboundPacket(pkt, inlineCtx)
return
}
w := int(f.decryptCounter.Add(1)-1) % queueCount
if w < 0 || w >= queueCount || !f.decryptQueues[w].Enqueue(pkt) {
f.processInboundPacket(pkt, inlineCtx)
return
}
f.notifyDecryptWorker(w)
}) })
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
packet := make([]byte, tunReadBufferSize) packet := make([]byte, mtu)
out := make([]byte, tunReadBufferSize) out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
@@ -788,19 +461,6 @@ func (f *Interface) Close() error {
} }
} }
f.udpListenWG.Wait()
if f.decryptWorkers > 0 {
for _, state := range f.decryptStates {
if state.notify != nil {
select {
case state.notify <- struct{}{}:
default:
}
}
}
f.decryptWG.Wait()
}
// Release the tun device // Release the tun device
return f.inside.Close() return f.inside.Close()
} }

View File

@@ -120,8 +120,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
} }
udp.SetDisableUDPCsum(c.GetBool("listen.disable_udp_checksum", false))
var tun overlay.Device var tun overlay.Device
if !configTest { if !configTest {
c.CatchHUP(ctx) c.CatchHUP(ctx)
@@ -223,9 +221,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
} }
decryptWorkers := c.GetInt("listen.decrypt_workers", 0)
decryptQueueDepth := c.GetInt("listen.decrypt_queue_depth", 0)
ifConfig := &InterfaceConfig{ ifConfig := &InterfaceConfig{
HostMap: hostMap, HostMap: hostMap,
Inside: tun, Inside: tun,
@@ -248,8 +243,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
punchy: punchy, punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
l: l, l: l,
DecryptWorkers: decryptWorkers,
DecryptQueueDepth: decryptQueueDepth,
} }
var ifce *Interface var ifce *Interface

View File

@@ -470,13 +470,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l). hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
WithError(err).
WithField("tag", "decrypt-debug").
WithField("remoteIndexLocal", hostinfo.localIndexId).
WithField("messageCounter", messageCounter).
WithField("packet_len", len(packet)).
Error("Failed to decrypt packet")
return false return false
} }

View File

@@ -25,17 +25,14 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MaxMTU int MaxMTU int
DefaultMTU int DefaultMTU int
TXQueueLen int TXQueueLen int
deviceIndex int deviceIndex int
ioctlFd uintptr ioctlFd uintptr
enableVnetHdr bool
vnetHdrLen int
queues []*tunQueue
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -68,90 +65,10 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
const (
virtioNetHdrLen = 12
tunDefaultMaxPacket = 65536
)
type tunQueue struct {
file *os.File
fd int
enableVnetHdr bool
vnetHdrLen int
maxPacket int
writeScratch []byte
readScratch []byte
l *logrus.Logger
}
func newTunQueue(file *os.File, enableVnetHdr bool, vnetHdrLen, maxPacket int, l *logrus.Logger) *tunQueue {
if maxPacket <= 0 {
maxPacket = tunDefaultMaxPacket
}
q := &tunQueue{
file: file,
fd: int(file.Fd()),
enableVnetHdr: enableVnetHdr,
vnetHdrLen: vnetHdrLen,
maxPacket: maxPacket,
l: l,
}
if enableVnetHdr {
q.growReadScratch(maxPacket)
}
return q
}
func (q *tunQueue) growReadScratch(packetSize int) {
needed := q.vnetHdrLen + packetSize
if needed < q.vnetHdrLen+DefaultMTU {
needed = q.vnetHdrLen + DefaultMTU
}
if q.readScratch == nil || cap(q.readScratch) < needed {
q.readScratch = make([]byte, needed)
} else {
q.readScratch = q.readScratch[:needed]
}
}
func (q *tunQueue) setMaxPacket(packet int) {
if packet <= 0 {
packet = DefaultMTU
}
q.maxPacket = packet
if q.enableVnetHdr {
q.growReadScratch(packet)
}
}
func configureVnetHdr(fd int, hdrLen int, l *logrus.Logger) error {
features, err := unix.IoctlGetInt(fd, unix.TUNGETFEATURES)
if err == nil && features&unix.IFF_VNET_HDR == 0 {
return fmt.Errorf("kernel does not support IFF_VNET_HDR")
}
if err := unix.IoctlSetInt(fd, unix.TUNSETVNETHDRSZ, hdrLen); err != nil {
return err
}
offload := unix.TUN_F_CSUM | unix.TUN_F_UFO
if err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offload); err != nil {
if l != nil {
l.WithError(err).Warn("Failed to enable TUN offload features")
}
}
return nil
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
if enableVnetHdr {
if err := configureVnetHdr(deviceFd, virtioNetHdrLen, l); err != nil {
l.WithError(err).Warn("Failed to configure VNET header support on provided tun fd; disabling")
enableVnetHdr = false
}
}
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -189,25 +106,14 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
if multiqueue { if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE req.Flags |= unix.IFF_MULTI_QUEUE
} }
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
if enableVnetHdr {
req.Flags |= unix.IFF_VNET_HDR
}
copy(req.Name[:], c.GetString("tun.dev", "")) copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err return nil, err
} }
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
if enableVnetHdr {
if err := configureVnetHdr(fd, virtioNetHdrLen, l); err != nil {
l.WithError(err).Warn("Failed to configure VNET header support on tun device; disabling")
enableVnetHdr = false
}
}
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -217,30 +123,21 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil return t, nil
} }
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, enableVnetHdr bool) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
queue := newTunQueue(file, enableVnetHdr, virtioNetHdrLen, tunDefaultMaxPacket, l)
t := &tun{ t := &tun{
ReadWriteCloser: queue, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l, l: l,
enableVnetHdr: enableVnetHdr,
vnetHdrLen: virtioNetHdrLen,
queues: []*tunQueue{queue},
} }
err := t.reload(c, true) err := t.reload(c, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if enableVnetHdr {
for _, q := range t.queues {
q.setMaxPacket(t.MaxMTU)
}
}
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false) err := t.reload(c, false)
@@ -283,11 +180,6 @@ func (t *tun) reload(c *config.C, initial bool) error {
t.MaxMTU = newMaxMTU t.MaxMTU = newMaxMTU
t.DefaultMTU = newDefaultMTU t.DefaultMTU = newDefaultMTU
if t.enableVnetHdr {
for _, q := range t.queues {
q.setMaxPacket(t.MaxMTU)
}
}
// Teach nebula how to handle the routes before establishing them in the system table // Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes) oldRoutes := t.Routes.Swap(&routes)
@@ -332,87 +224,14 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
var req ifReq var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
if t.enableVnetHdr {
req.Flags |= unix.IFF_VNET_HDR
}
copy(req.Name[:], t.Device) copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err return nil, err
} }
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
queue := newTunQueue(file, t.enableVnetHdr, t.vnetHdrLen, t.MaxMTU, t.l)
if t.enableVnetHdr {
if err := configureVnetHdr(fd, t.vnetHdrLen, t.l); err != nil {
queue.enableVnetHdr = false
}
}
t.queues = append(t.queues, queue)
return queue, nil return file, nil
}
func (q *tunQueue) Read(p []byte) (int, error) {
if !q.enableVnetHdr {
return q.file.Read(p)
}
if len(p)+q.vnetHdrLen > cap(q.readScratch) {
q.growReadScratch(len(p))
}
buf := q.readScratch[:cap(q.readScratch)]
n, err := q.file.Read(buf)
if n <= 0 {
return n, err
}
if n < q.vnetHdrLen {
if err == nil {
err = io.ErrUnexpectedEOF
}
return 0, err
}
payload := buf[q.vnetHdrLen:n]
if len(payload) > len(p) {
copy(p, payload[:len(p)])
if err == nil {
err = io.ErrShortBuffer
}
return len(p), err
}
copy(p, payload)
return len(payload), err
}
func (q *tunQueue) Write(b []byte) (int, error) {
if !q.enableVnetHdr {
return unix.Write(q.fd, b)
}
total := q.vnetHdrLen + len(b)
if cap(q.writeScratch) < total {
q.writeScratch = make([]byte, total)
} else {
q.writeScratch = q.writeScratch[:total]
}
for i := 0; i < q.vnetHdrLen; i++ {
q.writeScratch[i] = 0
}
copy(q.writeScratch[q.vnetHdrLen:], b)
n, err := unix.Write(q.fd, q.writeScratch)
if n >= q.vnetHdrLen {
n -= q.vnetHdrLen
} else {
n = 0
}
return n, err
}
func (q *tunQueue) Close() error {
return q.file.Close()
} }
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {

1740
udp/io_uring_linux.go Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,9 @@
package udp package udp
import ( import (
"errors"
"fmt"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -79,3 +82,52 @@ func setIovecBase(msg *rawMessage, buf []byte) {
iov.Base = &buf[0] iov.Base = &buf[0]
iov.Len = uint32(len(buf)) iov.Len = uint32(len(buf))
} }
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
var hdr unix.Msghdr
var iov unix.Iovec
if msg == nil {
return hdr, iov, errors.New("nil rawMessage")
}
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
return hdr, iov, errors.New("rawMessage missing payload buffer")
}
payloadLen := int(msg.Hdr.Iov.Len)
if payloadLen < 0 {
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
}
iov.Base = msg.Hdr.Iov.Base
iov.Len = uint32(payloadLen)
hdr.Iov = &iov
hdr.Iovlen = 1
hdr.Name = msg.Hdr.Name
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
if hdr.Name != nil {
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
} else {
hdr.Namelen = 0
}
hdr.Control = msg.Hdr.Control
// CRITICAL: Use the allocated size, not what was previously returned
if hdr.Control != nil {
// Control buffer size is stored in Controllen from PrepareRawMessages
hdr.Controllen = msg.Hdr.Controllen
} else {
hdr.Controllen = 0
}
hdr.Flags = 0 // Reset flags for new receive
return hdr, iov, nil
}
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
if msg == nil || hdr == nil {
return
}
msg.Hdr.Namelen = hdr.Namelen
msg.Hdr.Controllen = hdr.Controllen
msg.Hdr.Flags = hdr.Flags
if n < 0 {
n = 0
}
msg.Len = uint32(n)
}

View File

@@ -7,6 +7,9 @@
package udp package udp
import ( import (
"errors"
"fmt"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -80,3 +83,52 @@ func setIovecBase(msg *rawMessage, buf []byte) {
iov.Base = &buf[0] iov.Base = &buf[0]
iov.Len = uint64(len(buf)) iov.Len = uint64(len(buf))
} }
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
var hdr unix.Msghdr
var iov unix.Iovec
if msg == nil {
return hdr, iov, errors.New("nil rawMessage")
}
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
return hdr, iov, errors.New("rawMessage missing payload buffer")
}
payloadLen := int(msg.Hdr.Iov.Len)
if payloadLen < 0 {
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
}
iov.Base = msg.Hdr.Iov.Base
iov.Len = uint64(payloadLen)
hdr.Iov = &iov
hdr.Iovlen = 1
hdr.Name = msg.Hdr.Name
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
if hdr.Name != nil {
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
} else {
hdr.Namelen = 0
}
hdr.Control = msg.Hdr.Control
// CRITICAL: Use the allocated size, not what was previously returned
if hdr.Control != nil {
// Control buffer size is stored in Controllen from PrepareRawMessages
hdr.Controllen = msg.Hdr.Controllen
} else {
hdr.Controllen = 0
}
hdr.Flags = 0 // Reset flags for new receive
return hdr, iov, nil
}
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
if msg == nil || hdr == nil {
return
}
msg.Hdr.Namelen = hdr.Namelen
msg.Hdr.Controllen = hdr.Controllen
msg.Hdr.Flags = hdr.Flags
if n < 0 {
n = 0
}
msg.Len = uint32(n)
}