mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 17:04:25 +01:00
Compare commits
4 Commits
io-uring-g
...
stinkier
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29157f413c | ||
|
|
68746bd907 | ||
|
|
51b383f680 | ||
|
|
71c849e63e |
18
CHANGELOG.md
18
CHANGELOG.md
@@ -7,12 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Experimental Linux UDP offload support: enable `listen.enable_gso` and
|
||||||
|
`listen.enable_gro` to activate UDP_SEGMENT batching and GRO receive
|
||||||
|
splitting. Includes automatic capability probing, per-packet fallbacks, and
|
||||||
|
runtime metrics/logs for visibility.
|
||||||
|
- Optional Linux TUN `virtio_net_hdr` support: set `tun.enable_vnet_hdr` to
|
||||||
|
have Nebula negotiate VNET headers and offload flags so future batches can
|
||||||
|
be delivered to the kernel with metadata instead of per-packet writes.
|
||||||
|
- Linux UDP send sharding can now be tuned with `listen.send_shards`; defaults
|
||||||
|
to `GOMAXPROCS` but can be increased to stripe heavy peers across more
|
||||||
|
goroutines.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||||
deprecated and will be removed in a future release.
|
deprecated and will be removed in a future release.
|
||||||
|
- UDP receive path now enqueues into per-worker lock-free rings, restoring the
|
||||||
|
`listen.decrypt_workers`/`listen.decrypt_queue_depth` tuning knobs while
|
||||||
|
eliminating the mutex contention from the old shared channel.
|
||||||
|
- Increased replay protection window to 32k packets so high-throughput links
|
||||||
|
tolerate larger bursts of reordering without tripping the anti-replay logic.
|
||||||
|
|
||||||
## [1.9.4] - 2024-09-09
|
## [1.9.4] - 2024-09-09
|
||||||
|
|
||||||
|
|||||||
70
cert/pem.go
70
cert/pem.go
@@ -1,10 +1,8 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
@@ -191,71 +189,3 @@ func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error)
|
|||||||
}
|
}
|
||||||
return k.Bytes, r, curve, nil
|
return k.Bytes, r, curve, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backward compatibility functions for older API
|
|
||||||
func MarshalX25519PublicKey(b []byte) []byte {
|
|
||||||
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalX25519PrivateKey(b []byte) []byte {
|
|
||||||
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalPublicKey(curve Curve, b []byte) []byte {
|
|
||||||
return MarshalPublicKeyToPEM(curve, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalPrivateKey(curve Curve, b []byte) []byte {
|
|
||||||
return MarshalPrivateKeyToPEM(curve, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NebulaCertificate is a compatibility wrapper for the old API
|
|
||||||
type NebulaCertificate struct {
|
|
||||||
Details NebulaCertificateDetails
|
|
||||||
Signature []byte
|
|
||||||
cert Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
// NebulaCertificateDetails is a compatibility wrapper for certificate details
|
|
||||||
type NebulaCertificateDetails struct {
|
|
||||||
Name string
|
|
||||||
NotBefore time.Time
|
|
||||||
NotAfter time.Time
|
|
||||||
PublicKey []byte
|
|
||||||
IsCA bool
|
|
||||||
Issuer []byte
|
|
||||||
Curve Curve
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
|
|
||||||
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
|
|
||||||
c, rest, err := UnmarshalCertificateFromPEM(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, rest, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to old format
|
|
||||||
nc := &NebulaCertificate{
|
|
||||||
Details: NebulaCertificateDetails{
|
|
||||||
Name: c.Name(),
|
|
||||||
NotBefore: c.NotBefore(),
|
|
||||||
NotAfter: c.NotAfter(),
|
|
||||||
PublicKey: c.PublicKey(),
|
|
||||||
IsCA: c.IsCA(),
|
|
||||||
Curve: c.Curve(),
|
|
||||||
},
|
|
||||||
Signature: c.Signature(),
|
|
||||||
cert: c,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle issuer
|
|
||||||
if c.Issuer() != "" {
|
|
||||||
issuerBytes, err := hex.DecodeString(c.Issuer())
|
|
||||||
if err != nil {
|
|
||||||
return nil, rest, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
|
|
||||||
}
|
|
||||||
nc.Details.Issuer = issuerBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
return nc, rest, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,7 +13,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ReplayWindow = 1024
|
// ReplayWindow controls the size of the sliding window used to detect replays.
|
||||||
|
// High-bandwidth links with GRO/GSO can reorder more than a thousand packets in
|
||||||
|
// flight, so keep this comfortably above the largest expected burst.
|
||||||
|
const ReplayWindow = 32768
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
eKey *NebulaCipherState
|
eKey *NebulaCipherState
|
||||||
|
|||||||
368
interface.go
368
interface.go
@@ -5,9 +5,11 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math/bits"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,7 +23,12 @@ import (
|
|||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const (
|
||||||
|
mtu = 9001
|
||||||
|
tunReadBufferSize = mtu * 8
|
||||||
|
defaultDecryptWorkerFactor = 2
|
||||||
|
defaultInboundQueueDepth = 1024
|
||||||
|
)
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
@@ -48,6 +55,8 @@ type InterfaceConfig struct {
|
|||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
DecryptWorkers int
|
||||||
|
DecryptQueueDepth int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
@@ -92,7 +101,167 @@ type Interface struct {
|
|||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
ctx context.Context
|
||||||
|
udpListenWG sync.WaitGroup
|
||||||
|
inboundPool sync.Pool
|
||||||
|
decryptWG sync.WaitGroup
|
||||||
|
decryptQueues []*inboundRing
|
||||||
|
decryptWorkers int
|
||||||
|
decryptStates []decryptWorkerState
|
||||||
|
decryptCounter atomic.Uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundPacket struct {
|
||||||
|
addr netip.AddrPort
|
||||||
|
payload []byte
|
||||||
|
release func()
|
||||||
|
queue int
|
||||||
|
}
|
||||||
|
|
||||||
|
type decryptWorkerState struct {
|
||||||
|
queue *inboundRing
|
||||||
|
notify chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type decryptContext struct {
|
||||||
|
ctTicker *firewall.ConntrackCacheTicker
|
||||||
|
plain []byte
|
||||||
|
head header.H
|
||||||
|
fwPacket firewall.Packet
|
||||||
|
light *LightHouseHandler
|
||||||
|
nebula []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundCell struct {
|
||||||
|
seq atomic.Uint64
|
||||||
|
pkt *inboundPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundRing struct {
|
||||||
|
mask uint64
|
||||||
|
cells []inboundCell
|
||||||
|
enqueuePos atomic.Uint64
|
||||||
|
dequeuePos atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInboundRing(capacity int) *inboundRing {
|
||||||
|
if capacity < 2 {
|
||||||
|
capacity = 2
|
||||||
|
}
|
||||||
|
size := nextPowerOfTwo(uint32(capacity))
|
||||||
|
if size < 2 {
|
||||||
|
size = 2
|
||||||
|
}
|
||||||
|
ring := &inboundRing{
|
||||||
|
mask: uint64(size - 1),
|
||||||
|
cells: make([]inboundCell, size),
|
||||||
|
}
|
||||||
|
for i := range ring.cells {
|
||||||
|
ring.cells[i].seq.Store(uint64(i))
|
||||||
|
}
|
||||||
|
return ring
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextPowerOfTwo(v uint32) uint32 {
|
||||||
|
if v == 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 1 << (32 - bits.LeadingZeros32(v-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *inboundRing) Enqueue(pkt *inboundPacket) bool {
|
||||||
|
var cell *inboundCell
|
||||||
|
pos := r.enqueuePos.Load()
|
||||||
|
for {
|
||||||
|
cell = &r.cells[pos&r.mask]
|
||||||
|
seq := cell.seq.Load()
|
||||||
|
diff := int64(seq) - int64(pos)
|
||||||
|
if diff == 0 {
|
||||||
|
if r.enqueuePos.CompareAndSwap(pos, pos+1) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else if diff < 0 {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
pos = r.enqueuePos.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cell.pkt = pkt
|
||||||
|
cell.seq.Store(pos + 1)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *inboundRing) Dequeue() (*inboundPacket, bool) {
|
||||||
|
var cell *inboundCell
|
||||||
|
pos := r.dequeuePos.Load()
|
||||||
|
for {
|
||||||
|
cell = &r.cells[pos&r.mask]
|
||||||
|
seq := cell.seq.Load()
|
||||||
|
diff := int64(seq) - int64(pos+1)
|
||||||
|
if diff == 0 {
|
||||||
|
if r.dequeuePos.CompareAndSwap(pos, pos+1) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else if diff < 0 {
|
||||||
|
return nil, false
|
||||||
|
} else {
|
||||||
|
pos = r.dequeuePos.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pkt := cell.pkt
|
||||||
|
cell.pkt = nil
|
||||||
|
cell.seq.Store(pos + r.mask + 1)
|
||||||
|
return pkt, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) getInboundPacket() *inboundPacket {
|
||||||
|
if pkt, ok := f.inboundPool.Get().(*inboundPacket); ok && pkt != nil {
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
return &inboundPacket{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) putInboundPacket(pkt *inboundPacket) {
|
||||||
|
if pkt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pkt.addr = netip.AddrPort{}
|
||||||
|
pkt.payload = nil
|
||||||
|
pkt.release = nil
|
||||||
|
pkt.queue = 0
|
||||||
|
f.inboundPool.Put(pkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDecryptContext(f *Interface) *decryptContext {
|
||||||
|
return &decryptContext{
|
||||||
|
ctTicker: firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout),
|
||||||
|
plain: make([]byte, udp.MTU),
|
||||||
|
head: header.H{},
|
||||||
|
fwPacket: firewall.Packet{},
|
||||||
|
light: f.lightHouse.NewRequestHandler(),
|
||||||
|
nebula: make([]byte, 12, 12),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) processInboundPacket(pkt *inboundPacket, ctx *decryptContext) {
|
||||||
|
if pkt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if pkt.release != nil {
|
||||||
|
pkt.release()
|
||||||
|
}
|
||||||
|
f.putInboundPacket(pkt)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx.head = header.H{}
|
||||||
|
ctx.fwPacket = firewall.Packet{}
|
||||||
|
var cache firewall.ConntrackCache
|
||||||
|
if ctx.ctTicker != nil {
|
||||||
|
cache = ctx.ctTicker.Get(f.l)
|
||||||
|
}
|
||||||
|
f.readOutsidePackets(pkt.addr, nil, ctx.plain[:0], pkt.payload, &ctx.head, &ctx.fwPacket, ctx.light, ctx.nebula, pkt.queue, cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -162,6 +331,35 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
|
decryptWorkers := c.DecryptWorkers
|
||||||
|
if decryptWorkers < 0 {
|
||||||
|
decryptWorkers = 0
|
||||||
|
}
|
||||||
|
if decryptWorkers == 0 {
|
||||||
|
decryptWorkers = c.routines * defaultDecryptWorkerFactor
|
||||||
|
if decryptWorkers < c.routines {
|
||||||
|
decryptWorkers = c.routines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if decryptWorkers < 0 {
|
||||||
|
decryptWorkers = 0
|
||||||
|
}
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
decryptWorkers = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
queueDepth := c.DecryptQueueDepth
|
||||||
|
if queueDepth <= 0 {
|
||||||
|
queueDepth = defaultInboundQueueDepth
|
||||||
|
}
|
||||||
|
minDepth := c.routines * 64
|
||||||
|
if minDepth <= 0 {
|
||||||
|
minDepth = 64
|
||||||
|
}
|
||||||
|
if queueDepth < minDepth {
|
||||||
|
queueDepth = minDepth
|
||||||
|
}
|
||||||
|
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
@@ -194,7 +392,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||||
},
|
},
|
||||||
|
|
||||||
l: c.l,
|
l: c.l,
|
||||||
|
ctx: ctx,
|
||||||
|
inboundPool: sync.Pool{New: func() any { return &inboundPacket{} }},
|
||||||
|
decryptWorkers: decryptWorkers,
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
@@ -203,6 +404,19 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
|
|
||||||
ifce.connectionManager.intf = ifce
|
ifce.connectionManager.intf = ifce
|
||||||
|
|
||||||
|
if decryptWorkers > 0 {
|
||||||
|
ifce.decryptQueues = make([]*inboundRing, decryptWorkers)
|
||||||
|
ifce.decryptStates = make([]decryptWorkerState, decryptWorkers)
|
||||||
|
for i := 0; i < decryptWorkers; i++ {
|
||||||
|
queue := newInboundRing(queueDepth)
|
||||||
|
ifce.decryptQueues[i] = queue
|
||||||
|
ifce.decryptStates[i] = decryptWorkerState{
|
||||||
|
queue: queue,
|
||||||
|
notify: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,8 +456,68 @@ func (f *Interface) activate() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Interface) startDecryptWorkers() {
|
||||||
|
if f.decryptWorkers <= 0 || len(f.decryptQueues) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.decryptWG.Add(f.decryptWorkers)
|
||||||
|
for i := 0; i < f.decryptWorkers; i++ {
|
||||||
|
go f.decryptWorker(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) decryptWorker(id int) {
|
||||||
|
defer f.decryptWG.Done()
|
||||||
|
if id < 0 || id >= len(f.decryptStates) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := f.decryptStates[id]
|
||||||
|
if state.queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := newDecryptContext(f)
|
||||||
|
for {
|
||||||
|
for {
|
||||||
|
pkt, ok := state.queue.Dequeue()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
f.processInboundPacket(pkt, ctx)
|
||||||
|
}
|
||||||
|
if f.closed.Load() || f.ctx.Err() != nil {
|
||||||
|
for {
|
||||||
|
pkt, ok := state.queue.Dequeue()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.processInboundPacket(pkt, ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
case <-state.notify:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) notifyDecryptWorker(idx int) {
|
||||||
|
if idx < 0 || idx >= len(f.decryptStates) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := f.decryptStates[idx]
|
||||||
|
if state.notify == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case state.notify <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) run() {
|
func (f *Interface) run() {
|
||||||
|
f.startDecryptWorkers()
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
|
f.udpListenWG.Add(f.routines)
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
@@ -256,6 +530,7 @@ func (f *Interface) run() {
|
|||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
defer f.udpListenWG.Done()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
@@ -264,26 +539,78 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
useWorkers := f.decryptWorkers > 0 && len(f.decryptQueues) > 0
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
var (
|
||||||
plaintext := make([]byte, udp.MTU)
|
inlineTicker *firewall.ConntrackCacheTicker
|
||||||
h := &header.H{}
|
inlineHandler *LightHouseHandler
|
||||||
fwPacket := &firewall.Packet{}
|
inlinePlain []byte
|
||||||
nb := make([]byte, 12, 12)
|
inlineHeader header.H
|
||||||
|
inlinePacket firewall.Packet
|
||||||
|
inlineNB []byte
|
||||||
|
inlineCtx *decryptContext
|
||||||
|
)
|
||||||
|
|
||||||
|
if useWorkers {
|
||||||
|
inlineCtx = newDecryptContext(f)
|
||||||
|
} else {
|
||||||
|
inlineTicker = firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
inlineHandler = f.lightHouse.NewRequestHandler()
|
||||||
|
inlinePlain = make([]byte, udp.MTU)
|
||||||
|
inlineNB = make([]byte, 12, 12)
|
||||||
|
}
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) {
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) {
|
||||||
if release != nil {
|
if !useWorkers {
|
||||||
defer release()
|
if release != nil {
|
||||||
|
defer release()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
inlineHeader = header.H{}
|
||||||
|
inlinePacket = firewall.Packet{}
|
||||||
|
var cache firewall.ConntrackCache
|
||||||
|
if inlineTicker != nil {
|
||||||
|
cache = inlineTicker.Get(f.l)
|
||||||
|
}
|
||||||
|
f.readOutsidePackets(fromUdpAddr, nil, inlinePlain[:0], payload, &inlineHeader, &inlinePacket, inlineHandler, inlineNB, i, cache)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
|
||||||
|
if f.ctx.Err() != nil {
|
||||||
|
if release != nil {
|
||||||
|
release()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := f.getInboundPacket()
|
||||||
|
pkt.addr = fromUdpAddr
|
||||||
|
pkt.payload = payload
|
||||||
|
pkt.release = release
|
||||||
|
pkt.queue = i
|
||||||
|
|
||||||
|
queueCount := len(f.decryptQueues)
|
||||||
|
if queueCount == 0 {
|
||||||
|
f.processInboundPacket(pkt, inlineCtx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w := int(f.decryptCounter.Add(1)-1) % queueCount
|
||||||
|
if w < 0 || w >= queueCount || !f.decryptQueues[w].Enqueue(pkt) {
|
||||||
|
f.processInboundPacket(pkt, inlineCtx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.notifyDecryptWorker(w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, tunReadBufferSize)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, tunReadBufferSize)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
@@ -461,6 +788,19 @@ func (f *Interface) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.udpListenWG.Wait()
|
||||||
|
if f.decryptWorkers > 0 {
|
||||||
|
for _, state := range f.decryptStates {
|
||||||
|
if state.notify != nil {
|
||||||
|
select {
|
||||||
|
case state.notify <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.decryptWG.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun device
|
||||||
return f.inside.Close()
|
return f.inside.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
7
main.go
7
main.go
@@ -120,6 +120,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
udp.SetDisableUDPCsum(c.GetBool("listen.disable_udp_checksum", false))
|
||||||
|
|
||||||
var tun overlay.Device
|
var tun overlay.Device
|
||||||
if !configTest {
|
if !configTest {
|
||||||
c.CatchHUP(ctx)
|
c.CatchHUP(ctx)
|
||||||
@@ -221,6 +223,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
decryptWorkers := c.GetInt("listen.decrypt_workers", 0)
|
||||||
|
decryptQueueDepth := c.GetInt("listen.decrypt_queue_depth", 0)
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
HostMap: hostMap,
|
HostMap: hostMap,
|
||||||
Inside: tun,
|
Inside: tun,
|
||||||
@@ -243,6 +248,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
l: l,
|
l: l,
|
||||||
|
DecryptWorkers: decryptWorkers,
|
||||||
|
DecryptQueueDepth: decryptQueueDepth,
|
||||||
}
|
}
|
||||||
|
|
||||||
var ifce *Interface
|
var ifce *Interface
|
||||||
|
|||||||
@@ -470,7 +470,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).
|
||||||
|
WithError(err).
|
||||||
|
WithField("tag", "decrypt-debug").
|
||||||
|
WithField("remoteIndexLocal", hostinfo.localIndexId).
|
||||||
|
WithField("messageCounter", messageCounter).
|
||||||
|
WithField("packet_len", len(packet)).
|
||||||
|
Error("Failed to decrypt packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,14 +25,17 @@ import (
|
|||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
fd int
|
fd int
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MaxMTU int
|
MaxMTU int
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
deviceIndex int
|
deviceIndex int
|
||||||
ioctlFd uintptr
|
ioctlFd uintptr
|
||||||
|
enableVnetHdr bool
|
||||||
|
vnetHdrLen int
|
||||||
|
queues []*tunQueue
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
@@ -65,10 +68,90 @@ type ifreqQLEN struct {
|
|||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
virtioNetHdrLen = 12
|
||||||
|
tunDefaultMaxPacket = 65536
|
||||||
|
)
|
||||||
|
|
||||||
|
type tunQueue struct {
|
||||||
|
file *os.File
|
||||||
|
fd int
|
||||||
|
enableVnetHdr bool
|
||||||
|
vnetHdrLen int
|
||||||
|
maxPacket int
|
||||||
|
writeScratch []byte
|
||||||
|
readScratch []byte
|
||||||
|
l *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunQueue(file *os.File, enableVnetHdr bool, vnetHdrLen, maxPacket int, l *logrus.Logger) *tunQueue {
|
||||||
|
if maxPacket <= 0 {
|
||||||
|
maxPacket = tunDefaultMaxPacket
|
||||||
|
}
|
||||||
|
q := &tunQueue{
|
||||||
|
file: file,
|
||||||
|
fd: int(file.Fd()),
|
||||||
|
enableVnetHdr: enableVnetHdr,
|
||||||
|
vnetHdrLen: vnetHdrLen,
|
||||||
|
maxPacket: maxPacket,
|
||||||
|
l: l,
|
||||||
|
}
|
||||||
|
if enableVnetHdr {
|
||||||
|
q.growReadScratch(maxPacket)
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) growReadScratch(packetSize int) {
|
||||||
|
needed := q.vnetHdrLen + packetSize
|
||||||
|
if needed < q.vnetHdrLen+DefaultMTU {
|
||||||
|
needed = q.vnetHdrLen + DefaultMTU
|
||||||
|
}
|
||||||
|
if q.readScratch == nil || cap(q.readScratch) < needed {
|
||||||
|
q.readScratch = make([]byte, needed)
|
||||||
|
} else {
|
||||||
|
q.readScratch = q.readScratch[:needed]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) setMaxPacket(packet int) {
|
||||||
|
if packet <= 0 {
|
||||||
|
packet = DefaultMTU
|
||||||
|
}
|
||||||
|
q.maxPacket = packet
|
||||||
|
if q.enableVnetHdr {
|
||||||
|
q.growReadScratch(packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureVnetHdr(fd int, hdrLen int, l *logrus.Logger) error {
|
||||||
|
features, err := unix.IoctlGetInt(fd, unix.TUNGETFEATURES)
|
||||||
|
if err == nil && features&unix.IFF_VNET_HDR == 0 {
|
||||||
|
return fmt.Errorf("kernel does not support IFF_VNET_HDR")
|
||||||
|
}
|
||||||
|
if err := unix.IoctlSetInt(fd, unix.TUNSETVNETHDRSZ, hdrLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
offload := unix.TUN_F_CSUM | unix.TUN_F_UFO
|
||||||
|
if err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offload); err != nil {
|
||||||
|
if l != nil {
|
||||||
|
l.WithError(err).Warn("Failed to enable TUN offload features")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
|
||||||
|
if enableVnetHdr {
|
||||||
|
if err := configureVnetHdr(deviceFd, virtioNetHdrLen, l); err != nil {
|
||||||
|
l.WithError(err).Warn("Failed to configure VNET header support on provided tun fd; disabling")
|
||||||
|
enableVnetHdr = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -106,14 +189,25 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
if multiqueue {
|
if multiqueue {
|
||||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||||
}
|
}
|
||||||
|
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
|
||||||
|
if enableVnetHdr {
|
||||||
|
req.Flags |= unix.IFF_VNET_HDR
|
||||||
|
}
|
||||||
copy(req.Name[:], c.GetString("tun.dev", ""))
|
copy(req.Name[:], c.GetString("tun.dev", ""))
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
|
if enableVnetHdr {
|
||||||
|
if err := configureVnetHdr(fd, virtioNetHdrLen, l); err != nil {
|
||||||
|
l.WithError(err).Warn("Failed to configure VNET header support on tun device; disabling")
|
||||||
|
enableVnetHdr = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -123,21 +217,30 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, enableVnetHdr bool) (*tun, error) {
|
||||||
|
queue := newTunQueue(file, enableVnetHdr, virtioNetHdrLen, tunDefaultMaxPacket, l)
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: queue,
|
||||||
fd: int(file.Fd()),
|
fd: int(file.Fd()),
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
l: l,
|
l: l,
|
||||||
|
enableVnetHdr: enableVnetHdr,
|
||||||
|
vnetHdrLen: virtioNetHdrLen,
|
||||||
|
queues: []*tunQueue{queue},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if enableVnetHdr {
|
||||||
|
for _, q := range t.queues {
|
||||||
|
q.setMaxPacket(t.MaxMTU)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := t.reload(c, false)
|
err := t.reload(c, false)
|
||||||
@@ -180,6 +283,11 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
|
|
||||||
t.MaxMTU = newMaxMTU
|
t.MaxMTU = newMaxMTU
|
||||||
t.DefaultMTU = newDefaultMTU
|
t.DefaultMTU = newDefaultMTU
|
||||||
|
if t.enableVnetHdr {
|
||||||
|
for _, q := range t.queues {
|
||||||
|
q.setMaxPacket(t.MaxMTU)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
@@ -224,14 +332,87 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
|
|
||||||
var req ifReq
|
var req ifReq
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
|
if t.enableVnetHdr {
|
||||||
|
req.Flags |= unix.IFF_VNET_HDR
|
||||||
|
}
|
||||||
copy(req.Name[:], t.Device)
|
copy(req.Name[:], t.Device)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
|
queue := newTunQueue(file, t.enableVnetHdr, t.vnetHdrLen, t.MaxMTU, t.l)
|
||||||
|
if t.enableVnetHdr {
|
||||||
|
if err := configureVnetHdr(fd, t.vnetHdrLen, t.l); err != nil {
|
||||||
|
queue.enableVnetHdr = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.queues = append(t.queues, queue)
|
||||||
|
|
||||||
return file, nil
|
return queue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) Read(p []byte) (int, error) {
|
||||||
|
if !q.enableVnetHdr {
|
||||||
|
return q.file.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p)+q.vnetHdrLen > cap(q.readScratch) {
|
||||||
|
q.growReadScratch(len(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := q.readScratch[:cap(q.readScratch)]
|
||||||
|
n, err := q.file.Read(buf)
|
||||||
|
if n <= 0 {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if n < q.vnetHdrLen {
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := buf[q.vnetHdrLen:n]
|
||||||
|
if len(payload) > len(p) {
|
||||||
|
copy(p, payload[:len(p)])
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
return len(p), err
|
||||||
|
}
|
||||||
|
copy(p, payload)
|
||||||
|
return len(payload), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) Write(b []byte) (int, error) {
|
||||||
|
if !q.enableVnetHdr {
|
||||||
|
return unix.Write(q.fd, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
total := q.vnetHdrLen + len(b)
|
||||||
|
if cap(q.writeScratch) < total {
|
||||||
|
q.writeScratch = make([]byte, total)
|
||||||
|
} else {
|
||||||
|
q.writeScratch = q.writeScratch[:total]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < q.vnetHdrLen; i++ {
|
||||||
|
q.writeScratch[i] = 0
|
||||||
|
}
|
||||||
|
copy(q.writeScratch[q.vnetHdrLen:], b)
|
||||||
|
|
||||||
|
n, err := unix.Write(q.fd, q.writeScratch)
|
||||||
|
if n >= q.vnetHdrLen {
|
||||||
|
n -= q.vnetHdrLen
|
||||||
|
} else {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) Close() error {
|
||||||
|
return q.file.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
1753
udp/udp_linux.go
1753
udp/udp_linux.go
File diff suppressed because it is too large
Load Diff
@@ -7,9 +7,6 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -82,52 +79,3 @@ func setIovecBase(msg *rawMessage, buf []byte) {
|
|||||||
iov.Base = &buf[0]
|
iov.Base = &buf[0]
|
||||||
iov.Len = uint32(len(buf))
|
iov.Len = uint32(len(buf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
|
|
||||||
var hdr unix.Msghdr
|
|
||||||
var iov unix.Iovec
|
|
||||||
if msg == nil {
|
|
||||||
return hdr, iov, errors.New("nil rawMessage")
|
|
||||||
}
|
|
||||||
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
|
|
||||||
return hdr, iov, errors.New("rawMessage missing payload buffer")
|
|
||||||
}
|
|
||||||
payloadLen := int(msg.Hdr.Iov.Len)
|
|
||||||
if payloadLen < 0 {
|
|
||||||
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
|
|
||||||
}
|
|
||||||
iov.Base = msg.Hdr.Iov.Base
|
|
||||||
iov.Len = uint32(payloadLen)
|
|
||||||
hdr.Iov = &iov
|
|
||||||
hdr.Iovlen = 1
|
|
||||||
hdr.Name = msg.Hdr.Name
|
|
||||||
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
|
|
||||||
if hdr.Name != nil {
|
|
||||||
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
|
|
||||||
} else {
|
|
||||||
hdr.Namelen = 0
|
|
||||||
}
|
|
||||||
hdr.Control = msg.Hdr.Control
|
|
||||||
// CRITICAL: Use the allocated size, not what was previously returned
|
|
||||||
if hdr.Control != nil {
|
|
||||||
// Control buffer size is stored in Controllen from PrepareRawMessages
|
|
||||||
hdr.Controllen = msg.Hdr.Controllen
|
|
||||||
} else {
|
|
||||||
hdr.Controllen = 0
|
|
||||||
}
|
|
||||||
hdr.Flags = 0 // Reset flags for new receive
|
|
||||||
return hdr, iov, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
|
|
||||||
if msg == nil || hdr == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msg.Hdr.Namelen = hdr.Namelen
|
|
||||||
msg.Hdr.Controllen = hdr.Controllen
|
|
||||||
msg.Hdr.Flags = hdr.Flags
|
|
||||||
if n < 0 {
|
|
||||||
n = 0
|
|
||||||
}
|
|
||||||
msg.Len = uint32(n)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,9 +7,6 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,52 +80,3 @@ func setIovecBase(msg *rawMessage, buf []byte) {
|
|||||||
iov.Base = &buf[0]
|
iov.Base = &buf[0]
|
||||||
iov.Len = uint64(len(buf))
|
iov.Len = uint64(len(buf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
|
|
||||||
var hdr unix.Msghdr
|
|
||||||
var iov unix.Iovec
|
|
||||||
if msg == nil {
|
|
||||||
return hdr, iov, errors.New("nil rawMessage")
|
|
||||||
}
|
|
||||||
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
|
|
||||||
return hdr, iov, errors.New("rawMessage missing payload buffer")
|
|
||||||
}
|
|
||||||
payloadLen := int(msg.Hdr.Iov.Len)
|
|
||||||
if payloadLen < 0 {
|
|
||||||
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
|
|
||||||
}
|
|
||||||
iov.Base = msg.Hdr.Iov.Base
|
|
||||||
iov.Len = uint64(payloadLen)
|
|
||||||
hdr.Iov = &iov
|
|
||||||
hdr.Iovlen = 1
|
|
||||||
hdr.Name = msg.Hdr.Name
|
|
||||||
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
|
|
||||||
if hdr.Name != nil {
|
|
||||||
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
|
|
||||||
} else {
|
|
||||||
hdr.Namelen = 0
|
|
||||||
}
|
|
||||||
hdr.Control = msg.Hdr.Control
|
|
||||||
// CRITICAL: Use the allocated size, not what was previously returned
|
|
||||||
if hdr.Control != nil {
|
|
||||||
// Control buffer size is stored in Controllen from PrepareRawMessages
|
|
||||||
hdr.Controllen = msg.Hdr.Controllen
|
|
||||||
} else {
|
|
||||||
hdr.Controllen = 0
|
|
||||||
}
|
|
||||||
hdr.Flags = 0 // Reset flags for new receive
|
|
||||||
return hdr, iov, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
|
|
||||||
if msg == nil || hdr == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msg.Hdr.Namelen = hdr.Namelen
|
|
||||||
msg.Hdr.Controllen = hdr.Controllen
|
|
||||||
msg.Hdr.Flags = hdr.Flags
|
|
||||||
if n < 0 {
|
|
||||||
n = 0
|
|
||||||
}
|
|
||||||
msg.Len = uint32(n)
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user