mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
17 Commits
batched-pa
...
jay.wren-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be90e4aa05 | ||
|
|
bc9711df68 | ||
|
|
4e333c76ba | ||
|
|
f29e21b411 | ||
|
|
8b32382cd9 | ||
|
|
518a78c9d2 | ||
|
|
7c3708561d | ||
|
|
a62ffca975 | ||
|
|
226787ea1f | ||
|
|
b2bc6a09ca | ||
|
|
0f9b33aa36 | ||
|
|
ef0a022375 | ||
|
|
b68e504865 | ||
|
|
3344a840d1 | ||
|
|
2bc9863e66 | ||
|
|
97b3972c11 | ||
|
|
0f305d5397 |
@@ -1,164 +0,0 @@
|
|||||||
package nebula
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay"
|
|
||||||
"github.com/slackhq/nebula/udp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// batchPipelines tracks whether the inside device can operate on packet batches
|
|
||||||
// and, if so, holds the shared packet pool sized for the virtio headroom and
|
|
||||||
// payload limits advertised by the device. It also owns the fan-in/fan-out
|
|
||||||
// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
|
|
||||||
type batchPipelines struct {
|
|
||||||
enabled bool
|
|
||||||
inside overlay.BatchCapableDevice
|
|
||||||
headroom int
|
|
||||||
payloadCap int
|
|
||||||
pool *overlay.PacketPool
|
|
||||||
batchSize int
|
|
||||||
routines int
|
|
||||||
rxQueues []chan *overlay.Packet
|
|
||||||
txQueues []chan queuedDatagram
|
|
||||||
tunQueues []chan *overlay.Packet
|
|
||||||
}
|
|
||||||
|
|
||||||
type queuedDatagram struct {
|
|
||||||
packet *overlay.Packet
|
|
||||||
addr netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
|
|
||||||
if device == nil || routines <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bcap, ok := device.(overlay.BatchCapableDevice)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
headroom := bcap.BatchHeadroom()
|
|
||||||
payload := bcap.BatchPayloadCap()
|
|
||||||
if maxSegments < 1 {
|
|
||||||
maxSegments = 1
|
|
||||||
}
|
|
||||||
requiredPayload := udp.MTU * maxSegments
|
|
||||||
if payload < requiredPayload {
|
|
||||||
payload = requiredPayload
|
|
||||||
}
|
|
||||||
batchSize := bcap.BatchSize()
|
|
||||||
if headroom <= 0 || payload <= 0 || batchSize <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bp.enabled = true
|
|
||||||
bp.inside = bcap
|
|
||||||
bp.headroom = headroom
|
|
||||||
bp.payloadCap = payload
|
|
||||||
bp.batchSize = batchSize
|
|
||||||
bp.routines = routines
|
|
||||||
bp.pool = overlay.NewPacketPool(headroom, payload)
|
|
||||||
queueCap := batchSize * defaultBatchQueueDepthFactor
|
|
||||||
if queueDepth > 0 {
|
|
||||||
queueCap = queueDepth
|
|
||||||
}
|
|
||||||
if queueCap < batchSize {
|
|
||||||
queueCap = batchSize
|
|
||||||
}
|
|
||||||
bp.rxQueues = make([]chan *overlay.Packet, routines)
|
|
||||||
bp.txQueues = make([]chan queuedDatagram, routines)
|
|
||||||
bp.tunQueues = make([]chan *overlay.Packet, routines)
|
|
||||||
for i := 0; i < routines; i++ {
|
|
||||||
bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
|
|
||||||
bp.txQueues[i] = make(chan queuedDatagram, queueCap)
|
|
||||||
bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) Pool() *overlay.PacketPool {
|
|
||||||
if bp == nil || !bp.enabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return bp.pool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) Enabled() bool {
|
|
||||||
return bp != nil && bp.enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) batchSizeHint() int {
|
|
||||||
if bp == nil || bp.batchSize <= 0 {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return bp.batchSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
|
|
||||||
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return bp.rxQueues[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
|
|
||||||
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return bp.txQueues[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
|
|
||||||
if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return bp.tunQueues[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) txQueueLen(i int) int {
|
|
||||||
q := bp.txQueue(i)
|
|
||||||
if q == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return len(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) tunQueueLen(i int) int {
|
|
||||||
q := bp.tunQueue(i)
|
|
||||||
if q == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return len(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
|
|
||||||
q := bp.rxQueue(i)
|
|
||||||
if q == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
q <- pkt
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
|
|
||||||
q := bp.txQueue(i)
|
|
||||||
if q == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
q <- queuedDatagram{packet: pkt, addr: addr}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
|
|
||||||
q := bp.tunQueue(i)
|
|
||||||
if q == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
q <- pkt
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *batchPipelines) newPacket() *overlay.Packet {
|
|
||||||
if bp == nil || !bp.enabled || bp.pool == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return bp.pool.Get()
|
|
||||||
}
|
|
||||||
97
cert/pem.go
97
cert/pem.go
@@ -1,10 +1,8 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
@@ -140,101 +138,6 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backward compatibility functions for older API
|
|
||||||
func MarshalX25519PublicKey(b []byte) []byte {
|
|
||||||
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalX25519PrivateKey(b []byte) []byte {
|
|
||||||
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalPublicKey(curve Curve, b []byte) []byte {
|
|
||||||
return MarshalPublicKeyToPEM(curve, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalPrivateKey(curve Curve, b []byte) []byte {
|
|
||||||
return MarshalPrivateKeyToPEM(curve, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NebulaCertificate is a compatibility wrapper for the old API
|
|
||||||
type NebulaCertificate struct {
|
|
||||||
Details NebulaCertificateDetails
|
|
||||||
Signature []byte
|
|
||||||
cert Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
// NebulaCertificateDetails is a compatibility wrapper for certificate details
|
|
||||||
type NebulaCertificateDetails struct {
|
|
||||||
Name string
|
|
||||||
NotBefore time.Time
|
|
||||||
NotAfter time.Time
|
|
||||||
PublicKey []byte
|
|
||||||
IsCA bool
|
|
||||||
Issuer []byte
|
|
||||||
Curve Curve
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
|
|
||||||
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
|
|
||||||
c, rest, err := UnmarshalCertificateFromPEM(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, rest, err
|
|
||||||
}
|
|
||||||
|
|
||||||
issuerBytes, err := func() ([]byte, error) {
|
|
||||||
issuer := c.Issuer()
|
|
||||||
if issuer == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
decoded, err := hex.DecodeString(issuer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
|
|
||||||
}
|
|
||||||
return decoded, nil
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
return nil, rest, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey := c.PublicKey()
|
|
||||||
if pubKey != nil {
|
|
||||||
pubKey = append([]byte(nil), pubKey...)
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := c.Signature()
|
|
||||||
if sig != nil {
|
|
||||||
sig = append([]byte(nil), sig...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &NebulaCertificate{
|
|
||||||
Details: NebulaCertificateDetails{
|
|
||||||
Name: c.Name(),
|
|
||||||
NotBefore: c.NotBefore(),
|
|
||||||
NotAfter: c.NotAfter(),
|
|
||||||
PublicKey: pubKey,
|
|
||||||
IsCA: c.IsCA(),
|
|
||||||
Issuer: issuerBytes,
|
|
||||||
Curve: c.Curve(),
|
|
||||||
},
|
|
||||||
Signature: sig,
|
|
||||||
cert: c,
|
|
||||||
}, rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IssuerString returns the issuer in hex format for compatibility
|
|
||||||
func (n *NebulaCertificate) IssuerString() string {
|
|
||||||
if n.Details.Issuer == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(n.Details.Issuer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Certificate returns the underlying certificate (read-only)
|
|
||||||
func (n *NebulaCertificate) Certificate() Certificate {
|
|
||||||
return n.cert
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
||||||
// consumed data or an error on failure
|
// consumed data or an error on failure
|
||||||
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||||
|
|||||||
12
firewall.go
12
firewall.go
@@ -423,7 +423,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
|||||||
|
|
||||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
// returns nil if the packet should not be dropped.
|
// returns nil if the packet should not be dropped.
|
||||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error {
|
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// Check if we spoke to this tuple, if we did then allow this packet
|
||||||
if f.inConns(fp, h, caPool, localCache) {
|
if f.inConns(fp, h, caPool, localCache) {
|
||||||
return nil
|
return nil
|
||||||
@@ -490,9 +490,11 @@ func (f *Firewall) EmitStats() {
|
|||||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool {
|
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||||
if localCache != nil && localCache.Has(fp) {
|
if localCache != nil {
|
||||||
return true
|
if _, ok := localCache[fp]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
conntrack.Lock()
|
conntrack.Lock()
|
||||||
@@ -557,7 +559,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
|
|
||||||
if localCache != nil {
|
if localCache != nil {
|
||||||
localCache.Add(fp)
|
localCache[fp] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -10,58 +9,13 @@ import (
|
|||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
// has been seen in the conntrack table.
|
// has been seen in the conntrack table.
|
||||||
type ConntrackCache struct {
|
type ConntrackCache map[Packet]struct{}
|
||||||
mu sync.Mutex
|
|
||||||
entries map[Packet]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConntrackCache() *ConntrackCache {
|
|
||||||
return &ConntrackCache{entries: make(map[Packet]struct{})}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ConntrackCache) Has(p Packet) bool {
|
|
||||||
if c == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
c.mu.Lock()
|
|
||||||
_, ok := c.entries[p]
|
|
||||||
c.mu.Unlock()
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ConntrackCache) Add(p Packet) {
|
|
||||||
if c == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.mu.Lock()
|
|
||||||
c.entries[p] = struct{}{}
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ConntrackCache) Len() int {
|
|
||||||
if c == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
c.mu.Lock()
|
|
||||||
l := len(c.entries)
|
|
||||||
c.mu.Unlock()
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ConntrackCache) Reset(capHint int) {
|
|
||||||
if c == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.mu.Lock()
|
|
||||||
c.entries = make(map[Packet]struct{}, capHint)
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConntrackCacheTicker struct {
|
type ConntrackCacheTicker struct {
|
||||||
cacheV uint64
|
cacheV uint64
|
||||||
cacheTick atomic.Uint64
|
cacheTick atomic.Uint64
|
||||||
|
|
||||||
cache *ConntrackCache
|
cache ConntrackCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||||
@@ -69,7 +23,9 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ConntrackCacheTicker{cache: newConntrackCache()}
|
c := &ConntrackCacheTicker{
|
||||||
|
cache: ConntrackCache{},
|
||||||
|
}
|
||||||
|
|
||||||
go c.tick(d)
|
go c.tick(d)
|
||||||
|
|
||||||
@@ -85,17 +41,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
|||||||
|
|
||||||
// Get checks if the cache ticker has moved to the next version before returning
|
// Get checks if the cache ticker has moved to the next version before returning
|
||||||
// the map. If it has moved, we reset the map.
|
// the map. If it has moved, we reset the map.
|
||||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
|
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
if tick := c.cacheTick.Load(); tick != c.cacheV {
|
||||||
c.cacheV = tick
|
c.cacheV = tick
|
||||||
if ll := c.cache.Len(); ll > 0 {
|
if ll := len(c.cache); ll > 0 {
|
||||||
if l.Level == logrus.DebugLevel {
|
if l.Level == logrus.DebugLevel {
|
||||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||||
}
|
}
|
||||||
c.cache.Reset(ll)
|
c.cache = make(ConntrackCache, ll)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
251
inside.go
251
inside.go
@@ -2,18 +2,159 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
"github.com/slackhq/nebula/overlay"
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) {
|
// consumeInsidePackets processes multiple packets in a batch for improved performance
|
||||||
|
// packets: slice of packet buffers to process
|
||||||
|
// sizes: slice of packet sizes
|
||||||
|
// count: number of packets to process
|
||||||
|
// outs: slice of output buffers (one per packet) with virtio headroom
|
||||||
|
// q: queue index
|
||||||
|
// localCache: firewall conntrack cache
|
||||||
|
// batchPackets: pre-allocated slice for accumulating encrypted packets
|
||||||
|
// batchAddrs: pre-allocated slice for accumulating destination addresses
|
||||||
|
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
|
||||||
|
// Reusable per-packet state
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
|
||||||
|
// Reset batch accumulation slices (reuse capacity)
|
||||||
|
*batchPackets = (*batchPackets)[:0]
|
||||||
|
*batchAddrs = (*batchAddrs)[:0]
|
||||||
|
|
||||||
|
// Process each packet in the batch
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
packet := packets[i][:sizes[i]]
|
||||||
|
out := outs[i]
|
||||||
|
|
||||||
|
// Inline the consumeInsidePacket logic for better performance
|
||||||
|
err := newPacket(packet, false, fwPacket)
|
||||||
|
if err != nil {
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore local broadcast packets
|
||||||
|
if f.dropLocalBroadcast {
|
||||||
|
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||||
|
// Immediately forward packets from self to self.
|
||||||
|
if immediatelyForwardToSelf {
|
||||||
|
_, err := f.readers[q].Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).Error("Failed to forward to tun")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore multicast packets
|
||||||
|
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||||
|
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||||
|
})
|
||||||
|
|
||||||
|
if hostinfo == nil {
|
||||||
|
f.rejectInside(packet, out, q)
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||||
|
WithField("fwPacket", fwPacket).
|
||||||
|
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ready {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
|
if dropReason != nil {
|
||||||
|
f.rejectInside(packet, out, q)
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
hostinfo.logger(f.l).
|
||||||
|
WithField("fwPacket", fwPacket).
|
||||||
|
WithField("reason", dropReason).
|
||||||
|
Debugln("dropping outbound packet")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt and prepare packet for batch sending
|
||||||
|
ci := hostinfo.ConnectionState
|
||||||
|
if ci.eKey == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this needs relay - if so, send immediately and skip batching
|
||||||
|
useRelay := !hostinfo.remote.IsValid()
|
||||||
|
if useRelay {
|
||||||
|
// Handle relay sends individually (less common path)
|
||||||
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the packet for batch sending
|
||||||
|
if noiseutil.EncryptLockNeeded {
|
||||||
|
ci.writeLock.Lock()
|
||||||
|
}
|
||||||
|
c := ci.messageCounter.Add(1)
|
||||||
|
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
|
||||||
|
f.connectionManager.Out(hostinfo)
|
||||||
|
|
||||||
|
// Query lighthouse if needed
|
||||||
|
if hostinfo.lastRebindCount != f.rebindCount {
|
||||||
|
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||||
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
|
||||||
|
if noiseutil.EncryptLockNeeded {
|
||||||
|
ci.writeLock.Unlock()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).
|
||||||
|
WithField("counter", c).
|
||||||
|
Error("Failed to encrypt outgoing packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to batch
|
||||||
|
*batchPackets = append(*batchPackets, out)
|
||||||
|
*batchAddrs = append(*batchAddrs, hostinfo.remote)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send all accumulated packets in one batch
|
||||||
|
if len(*batchPackets) > 0 {
|
||||||
|
batchSize := len(*batchPackets)
|
||||||
|
f.batchMetrics.udpWriteSize.Update(int64(batchSize))
|
||||||
|
|
||||||
|
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -337,21 +478,9 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
if ci.eKey == nil {
|
if ci.eKey == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
target := remote
|
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||||
if !target.IsValid() {
|
|
||||||
target = hostinfo.remote
|
|
||||||
}
|
|
||||||
useRelay := !target.IsValid()
|
|
||||||
fullOut := out
|
fullOut := out
|
||||||
|
|
||||||
var pkt *overlay.Packet
|
|
||||||
if !useRelay && f.batches.Enabled() {
|
|
||||||
pkt = f.batches.newPacket()
|
|
||||||
if pkt != nil {
|
|
||||||
out = pkt.Payload()[:0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if useRelay {
|
if useRelay {
|
||||||
if len(out) < header.Len {
|
if len(out) < header.Len {
|
||||||
// out always has a capacity of mtu, but not always a length greater than the header.Len.
|
// out always has a capacity of mtu, but not always a length greater than the header.Len.
|
||||||
@@ -385,85 +514,41 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if len(p) > 0 && slicesOverlap(out, p) {
|
|
||||||
tmp := make([]byte, len(p))
|
|
||||||
copy(tmp, p)
|
|
||||||
p = tmp
|
|
||||||
}
|
|
||||||
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
|
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
|
||||||
if noiseutil.EncryptLockNeeded {
|
if noiseutil.EncryptLockNeeded {
|
||||||
ci.writeLock.Unlock()
|
ci.writeLock.Unlock()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if pkt != nil {
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
hostinfo.logger(f.l).WithError(err).
|
hostinfo.logger(f.l).WithError(err).
|
||||||
WithField("udpAddr", target).WithField("counter", c).
|
WithField("udpAddr", remote).WithField("counter", c).
|
||||||
WithField("attemptedCounter", c).
|
WithField("attemptedCounter", c).
|
||||||
Error("Failed to encrypt outgoing packet")
|
Error("Failed to encrypt outgoing packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if target.IsValid() {
|
if remote.IsValid() {
|
||||||
if pkt != nil {
|
err = f.writers[q].WriteTo(out, remote)
|
||||||
pkt.Len = len(out)
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"queue": q,
|
|
||||||
"dest": target,
|
|
||||||
"payload_len": pkt.Len,
|
|
||||||
"use_batches": true,
|
|
||||||
"remote_index": hostinfo.remoteIndexId,
|
|
||||||
}).Debug("enqueueing packet to UDP batch queue")
|
|
||||||
}
|
|
||||||
if f.tryQueuePacket(q, pkt, target) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"queue": q,
|
|
||||||
"dest": target,
|
|
||||||
}).Debug("failed to enqueue packet; falling back to immediate send")
|
|
||||||
}
|
|
||||||
f.writeImmediatePacket(q, pkt, target, hostinfo)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if f.tryQueueDatagram(q, out, target) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f.writeImmediate(q, out, target, hostinfo)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// fall back to relay path
|
|
||||||
if pkt != nil {
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to send via a relay
|
|
||||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
|
||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.relayState.DeleteRelay(relayIP)
|
hostinfo.logger(f.l).WithError(err).
|
||||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||||
continue
|
}
|
||||||
|
} else if hostinfo.remote.IsValid() {
|
||||||
|
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).
|
||||||
|
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to send via a relay
|
||||||
|
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||||
|
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.relayState.DeleteRelay(relayIP)
|
||||||
|
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// slicesOverlap reports whether the two byte slices share any portion of memory.
|
|
||||||
// cipher.AEAD.Seal requires plaintext and dst to live in disjoint regions.
|
|
||||||
func slicesOverlap(a, b []byte) bool {
|
|
||||||
if len(a) == 0 || len(b) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
aStart := uintptr(unsafe.Pointer(&a[0]))
|
|
||||||
aEnd := aStart + uintptr(len(a))
|
|
||||||
bStart := uintptr(unsafe.Pointer(&b[0]))
|
|
||||||
bEnd := bStart + uintptr(len(b))
|
|
||||||
return aStart < bEnd && bStart < aEnd
|
|
||||||
}
|
|
||||||
|
|||||||
764
interface.go
764
interface.go
@@ -4,11 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -22,13 +20,8 @@ import (
|
|||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const mtu = 9001
|
||||||
mtu = 9001
|
const virtioNetHdrLen = overlay.VirtioNetHdrLen
|
||||||
defaultGSOFlushInterval = 150 * time.Microsecond
|
|
||||||
defaultBatchQueueDepthFactor = 4
|
|
||||||
defaultGSOMaxSegments = 8
|
|
||||||
maxKernelGSOSegments = 64
|
|
||||||
)
|
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
@@ -43,9 +36,6 @@ type InterfaceConfig struct {
|
|||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
DropLocalBroadcast bool
|
DropLocalBroadcast bool
|
||||||
DropMulticast bool
|
DropMulticast bool
|
||||||
EnableGSO bool
|
|
||||||
EnableGRO bool
|
|
||||||
GSOMaxSegments int
|
|
||||||
routines int
|
routines int
|
||||||
MessageMetrics *MessageMetrics
|
MessageMetrics *MessageMetrics
|
||||||
version string
|
version string
|
||||||
@@ -57,11 +47,16 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
BatchFlushInterval time.Duration
|
|
||||||
BatchQueueDepth int
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type batchMetrics struct {
|
||||||
|
udpReadSize metrics.Histogram
|
||||||
|
tunReadSize metrics.Histogram
|
||||||
|
udpWriteSize metrics.Histogram
|
||||||
|
tunWriteSize metrics.Histogram
|
||||||
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
outside udp.Conn
|
outside udp.Conn
|
||||||
@@ -96,24 +91,14 @@ type Interface struct {
|
|||||||
version string
|
version string
|
||||||
|
|
||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
batchQueueDepth int
|
|
||||||
enableGSO bool
|
|
||||||
enableGRO bool
|
|
||||||
gsoMaxSegments int
|
|
||||||
batchUDPQueueGauge metrics.Gauge
|
|
||||||
batchUDPFlushCounter metrics.Counter
|
|
||||||
batchTunQueueGauge metrics.Gauge
|
|
||||||
batchTunFlushCounter metrics.Counter
|
|
||||||
batchFlushInterval atomic.Int64
|
|
||||||
sendSem chan struct{}
|
|
||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []overlay.BatchReadWriter
|
||||||
batches batchPipelines
|
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
batchMetrics *batchMetrics
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
@@ -184,22 +169,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
return nil, errors.New("no connection manager")
|
return nil, errors.New("no connection manager")
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.GSOMaxSegments <= 0 {
|
|
||||||
c.GSOMaxSegments = defaultGSOMaxSegments
|
|
||||||
}
|
|
||||||
if c.GSOMaxSegments > maxKernelGSOSegments {
|
|
||||||
c.GSOMaxSegments = maxKernelGSOSegments
|
|
||||||
}
|
|
||||||
if c.BatchQueueDepth <= 0 {
|
|
||||||
c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
|
|
||||||
}
|
|
||||||
if c.BatchFlushInterval < 0 {
|
|
||||||
c.BatchFlushInterval = 0
|
|
||||||
}
|
|
||||||
if c.BatchFlushInterval == 0 && c.EnableGSO {
|
|
||||||
c.BatchFlushInterval = defaultGSOFlushInterval
|
|
||||||
}
|
|
||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
@@ -216,7 +185,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
routines: c.routines,
|
routines: c.routines,
|
||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]udp.Conn, c.routines),
|
writers: make([]udp.Conn, c.routines),
|
||||||
readers: make([]io.ReadWriteCloser, c.routines),
|
readers: make([]overlay.BatchReadWriter, c.routines),
|
||||||
myVpnNetworks: cs.myVpnNetworks,
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
myVpnAddrs: cs.myVpnAddrs,
|
myVpnAddrs: cs.myVpnAddrs,
|
||||||
@@ -225,10 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
relayManager: c.relayManager,
|
relayManager: c.relayManager,
|
||||||
connectionManager: c.connectionManager,
|
connectionManager: c.connectionManager,
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
batchQueueDepth: c.BatchQueueDepth,
|
|
||||||
enableGSO: c.EnableGSO,
|
|
||||||
enableGRO: c.EnableGRO,
|
|
||||||
gsoMaxSegments: c.GSOMaxSegments,
|
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
messageMetrics: c.MessageMetrics,
|
messageMetrics: c.MessageMetrics,
|
||||||
@@ -236,30 +201,19 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
|
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
|
||||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||||
},
|
},
|
||||||
|
batchMetrics: &batchMetrics{
|
||||||
|
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
|
||||||
|
tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)),
|
||||||
|
udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)),
|
||||||
|
tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)),
|
||||||
|
},
|
||||||
|
|
||||||
l: c.l,
|
l: c.l,
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
|
|
||||||
ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
|
|
||||||
ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
|
|
||||||
ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
|
|
||||||
ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
|
|
||||||
ifce.sendSem = make(chan struct{}, c.routines)
|
|
||||||
ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
|
|
||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||||
if c.l.Level >= logrus.DebugLevel {
|
|
||||||
c.l.WithFields(logrus.Fields{
|
|
||||||
"enableGSO": c.EnableGSO,
|
|
||||||
"enableGRO": c.EnableGRO,
|
|
||||||
"gsoMaxSegments": c.GSOMaxSegments,
|
|
||||||
"batchQueueDepth": c.BatchQueueDepth,
|
|
||||||
"batchFlush": c.BatchFlushInterval,
|
|
||||||
"batching": ifce.batches.Enabled(),
|
|
||||||
}).Debug("initialized batch pipelines")
|
|
||||||
}
|
|
||||||
|
|
||||||
ifce.connectionManager.intf = ifce
|
ifce.connectionManager.intf = ifce
|
||||||
|
|
||||||
@@ -285,7 +239,7 @@ func (f *Interface) activate() {
|
|||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
var reader io.ReadWriteCloser = f.inside
|
var reader overlay.BatchReadWriter = f.inside
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
@@ -308,18 +262,6 @@ func (f *Interface) run() {
|
|||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.batches.Enabled() {
|
|
||||||
for i := 0; i < f.routines; i++ {
|
|
||||||
go f.runInsideBatchWorker(i)
|
|
||||||
go f.runTunWriteQueue(i)
|
|
||||||
go f.runSendQueue(i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenIn(f.readers[i], i)
|
go f.listenIn(f.readers[i], i)
|
||||||
@@ -338,625 +280,69 @@ func (f *Interface) listenOut(i int) {
|
|||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
plaintext := make([]byte, udp.MTU)
|
|
||||||
|
// Pre-allocate output buffers for batch processing
|
||||||
|
batchSize := li.BatchSize()
|
||||||
|
outs := make([][]byte, batchSize)
|
||||||
|
for idx := range outs {
|
||||||
|
// Allocate full buffer with virtio header space
|
||||||
|
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
|
||||||
|
}
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
|
||||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
if f.batches.Enabled() {
|
batchSize := reader.BatchSize()
|
||||||
if br, ok := reader.(overlay.BatchReader); ok {
|
|
||||||
f.listenInBatchLocked(reader, br, i)
|
// Allocate buffers for batch reading
|
||||||
return
|
bufs := make([][]byte, batchSize)
|
||||||
}
|
for idx := range bufs {
|
||||||
|
bufs[idx] = make([]byte, mtu)
|
||||||
|
}
|
||||||
|
sizes := make([]int, batchSize)
|
||||||
|
|
||||||
|
// Allocate output buffers for batch processing (one per packet)
|
||||||
|
// Each has virtio header headroom to avoid copies on write
|
||||||
|
outs := make([][]byte, batchSize)
|
||||||
|
for idx := range outs {
|
||||||
|
outBuf := make([]byte, virtioNetHdrLen+mtu)
|
||||||
|
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
|
||||||
}
|
}
|
||||||
|
|
||||||
f.listenInLegacyLocked(reader, i)
|
// Pre-allocate batch accumulation buffers for sending
|
||||||
}
|
batchPackets := make([][]byte, 0, batchSize)
|
||||||
|
batchAddrs := make([]netip.AddrPort, 0, batchSize)
|
||||||
|
|
||||||
func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
|
// Pre-allocate nonce buffer (reused for all encryptions)
|
||||||
packet := make([]byte, mtu)
|
nb := make([]byte, 12)
|
||||||
out := make([]byte, mtu)
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.BatchRead(bufs, sizes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
f.l.WithError(err).Error("Error while batch reading outbound packets")
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.batchMetrics.tunReadSize.Update(int64(n))
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
|
// Process all packets in the batch at once
|
||||||
pool := f.batches.Pool()
|
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
|
||||||
if pool == nil {
|
|
||||||
f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
|
|
||||||
f.listenInLegacyLocked(raw, i)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
packets, err := reader.ReadIntoBatch(pool)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if isVirtioHeadroomError(err) {
|
|
||||||
f.l.WithError(err).Warn("Batch reader fell back due to tun headroom issue")
|
|
||||||
f.listenInLegacyLocked(raw, i)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet batch")
|
|
||||||
os.Exit(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(packets) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, pkt := range packets {
|
|
||||||
if pkt == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !f.batches.enqueueRx(i, pkt) {
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) runInsideBatchWorker(i int) {
|
|
||||||
queue := f.batches.rxQueue(i)
|
|
||||||
if queue == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
out := make([]byte, mtu)
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
|
|
||||||
for pkt := range queue {
|
|
||||||
if pkt == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) runSendQueue(i int) {
|
|
||||||
queue := f.batches.txQueue(i)
|
|
||||||
if queue == nil {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writer := f.writerForIndex(i)
|
|
||||||
if writer == nil {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("queue", i).Debug("send queue worker started")
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if f.l.Level >= logrus.WarnLevel {
|
|
||||||
f.l.WithField("queue", i).Warn("send queue worker exited")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
batchCap := f.batches.batchSizeHint()
|
|
||||||
if batchCap <= 0 {
|
|
||||||
batchCap = 1
|
|
||||||
}
|
|
||||||
gsoLimit := f.effectiveGSOMaxSegments()
|
|
||||||
if gsoLimit > batchCap {
|
|
||||||
batchCap = gsoLimit
|
|
||||||
}
|
|
||||||
pending := make([]queuedDatagram, 0, batchCap)
|
|
||||||
var (
|
|
||||||
flushTimer *time.Timer
|
|
||||||
flushC <-chan time.Time
|
|
||||||
)
|
|
||||||
dispatch := func(reason string, timerFired bool) {
|
|
||||||
if len(pending) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
batch := pending
|
|
||||||
f.flushAndReleaseBatch(i, writer, batch, reason)
|
|
||||||
for idx := range batch {
|
|
||||||
batch[idx] = queuedDatagram{}
|
|
||||||
}
|
|
||||||
pending = pending[:0]
|
|
||||||
if flushTimer != nil {
|
|
||||||
if !timerFired {
|
|
||||||
if !flushTimer.Stop() {
|
|
||||||
select {
|
|
||||||
case <-flushTimer.C:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
flushTimer = nil
|
|
||||||
flushC = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
armTimer := func() {
|
|
||||||
delay := f.currentBatchFlushInterval()
|
|
||||||
if delay <= 0 {
|
|
||||||
dispatch("nogso", false)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if flushTimer == nil {
|
|
||||||
flushTimer = time.NewTimer(delay)
|
|
||||||
flushC = flushTimer.C
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case d := <-queue:
|
|
||||||
if d.packet == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"queue": i,
|
|
||||||
"payload_len": d.packet.Len,
|
|
||||||
"dest": d.addr,
|
|
||||||
}).Debug("send queue received packet")
|
|
||||||
}
|
|
||||||
pending = append(pending, d)
|
|
||||||
if gsoLimit > 0 && len(pending) >= gsoLimit {
|
|
||||||
dispatch("gso", false)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if len(pending) >= cap(pending) {
|
|
||||||
dispatch("cap", false)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
armTimer()
|
|
||||||
f.observeUDPQueueLen(i)
|
|
||||||
case <-flushC:
|
|
||||||
dispatch("timer", true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) runTunWriteQueue(i int) {
|
|
||||||
queue := f.batches.tunQueue(i)
|
|
||||||
if queue == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writer := f.batches.inside
|
|
||||||
if writer == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requiredHeadroom := writer.BatchHeadroom()
|
|
||||||
|
|
||||||
batchCap := f.batches.batchSizeHint()
|
|
||||||
if batchCap <= 0 {
|
|
||||||
batchCap = 1
|
|
||||||
}
|
|
||||||
pending := make([]*overlay.Packet, 0, batchCap)
|
|
||||||
var (
|
|
||||||
flushTimer *time.Timer
|
|
||||||
flushC <-chan time.Time
|
|
||||||
)
|
|
||||||
flush := func(reason string, timerFired bool) {
|
|
||||||
if len(pending) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
valid := pending[:0]
|
|
||||||
for idx := range pending {
|
|
||||||
if !f.ensurePacketHeadroom(&pending[idx], requiredHeadroom, i, reason) {
|
|
||||||
pending[idx] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if pending[idx] != nil {
|
|
||||||
valid = append(valid, pending[idx])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(valid) > 0 {
|
|
||||||
if _, err := writer.WriteBatch(valid); err != nil {
|
|
||||||
f.l.WithError(err).
|
|
||||||
WithField("queue", i).
|
|
||||||
WithField("reason", reason).
|
|
||||||
Warn("Failed to write tun batch")
|
|
||||||
for _, pkt := range valid {
|
|
||||||
if pkt != nil {
|
|
||||||
f.writePacketToTun(i, pkt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pending = pending[:0]
|
|
||||||
if flushTimer != nil {
|
|
||||||
if !timerFired {
|
|
||||||
if !flushTimer.Stop() {
|
|
||||||
select {
|
|
||||||
case <-flushTimer.C:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
flushTimer = nil
|
|
||||||
flushC = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
armTimer := func() {
|
|
||||||
delay := f.currentBatchFlushInterval()
|
|
||||||
if delay <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if flushTimer == nil {
|
|
||||||
flushTimer = time.NewTimer(delay)
|
|
||||||
flushC = flushTimer.C
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case pkt := <-queue:
|
|
||||||
if pkt == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if f.ensurePacketHeadroom(&pkt, requiredHeadroom, i, "queue") {
|
|
||||||
pending = append(pending, pkt)
|
|
||||||
}
|
|
||||||
if len(pending) >= cap(pending) {
|
|
||||||
flush("cap", false)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
armTimer()
|
|
||||||
f.observeTunQueueLen(i)
|
|
||||||
case <-flushC:
|
|
||||||
flush("timer", true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
|
|
||||||
if len(batch) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f.flushDatagrams(index, writer, batch, reason)
|
|
||||||
for idx := range batch {
|
|
||||||
if batch[idx].packet != nil {
|
|
||||||
batch[idx].packet.Release()
|
|
||||||
batch[idx].packet = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if f.batchUDPFlushCounter != nil {
|
|
||||||
f.batchUDPFlushCounter.Inc(int64(len(batch)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
|
|
||||||
if len(batch) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"writer": index,
|
|
||||||
"reason": reason,
|
|
||||||
"pending": len(batch),
|
|
||||||
}).Debug("udp batch flush summary")
|
|
||||||
}
|
|
||||||
maxSeg := f.effectiveGSOMaxSegments()
|
|
||||||
if bw, ok := writer.(udp.BatchConn); ok {
|
|
||||||
chunkCap := maxSeg
|
|
||||||
if chunkCap <= 0 {
|
|
||||||
chunkCap = len(batch)
|
|
||||||
}
|
|
||||||
chunk := make([]udp.Datagram, 0, chunkCap)
|
|
||||||
var (
|
|
||||||
currentAddr netip.AddrPort
|
|
||||||
segments int
|
|
||||||
)
|
|
||||||
flushChunk := func() {
|
|
||||||
if len(chunk) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"writer": index,
|
|
||||||
"segments": len(chunk),
|
|
||||||
"dest": chunk[0].Addr,
|
|
||||||
"reason": reason,
|
|
||||||
"pending_total": len(batch),
|
|
||||||
}).Debug("flushing UDP batch")
|
|
||||||
}
|
|
||||||
if err := bw.WriteBatch(chunk); err != nil {
|
|
||||||
f.l.WithError(err).
|
|
||||||
WithField("writer", index).
|
|
||||||
WithField("reason", reason).
|
|
||||||
Warn("Failed to write UDP batch")
|
|
||||||
}
|
|
||||||
chunk = chunk[:0]
|
|
||||||
segments = 0
|
|
||||||
}
|
|
||||||
for _, item := range batch {
|
|
||||||
if item.packet == nil || !item.addr.IsValid() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
payload := item.packet.Payload()[:item.packet.Len]
|
|
||||||
if segments == 0 {
|
|
||||||
currentAddr = item.addr
|
|
||||||
}
|
|
||||||
if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) {
|
|
||||||
flushChunk()
|
|
||||||
currentAddr = item.addr
|
|
||||||
}
|
|
||||||
chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr})
|
|
||||||
segments++
|
|
||||||
}
|
|
||||||
flushChunk()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, item := range batch {
|
|
||||||
if item.packet == nil || !item.addr.IsValid() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"writer": index,
|
|
||||||
"reason": reason,
|
|
||||||
"dest": item.addr,
|
|
||||||
"segments": 1,
|
|
||||||
}).Debug("flushing UDP batch")
|
|
||||||
}
|
|
||||||
if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil {
|
|
||||||
f.l.WithError(err).
|
|
||||||
WithField("writer", index).
|
|
||||||
WithField("udpAddr", item.addr).
|
|
||||||
WithField("reason", reason).
|
|
||||||
Warn("Failed to write UDP packet")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool {
|
|
||||||
if !addr.IsValid() || !f.batches.Enabled() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
pkt := f.batches.newPacket()
|
|
||||||
if pkt == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
payload := pkt.Payload()
|
|
||||||
if len(payload) < len(buf) {
|
|
||||||
pkt.Release()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
copy(payload, buf)
|
|
||||||
pkt.Len = len(buf)
|
|
||||||
if f.batches.enqueueTx(q, pkt, addr) {
|
|
||||||
f.observeUDPQueueLen(q)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
pkt.Release()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) writerForIndex(i int) udp.Conn {
|
|
||||||
if i < 0 || i >= len(f.writers) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return f.writers[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) {
|
|
||||||
writer := f.writerForIndex(q)
|
|
||||||
if writer == nil {
|
|
||||||
f.l.WithField("udpAddr", addr).
|
|
||||||
WithField("writer", q).
|
|
||||||
Error("Failed to write outgoing packet: no writer available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := writer.WriteTo(buf, addr); err != nil {
|
|
||||||
hostinfo.logger(f.l).
|
|
||||||
WithError(err).
|
|
||||||
WithField("udpAddr", addr).
|
|
||||||
Error("Failed to write outgoing packet")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool {
|
|
||||||
if pkt == nil || !addr.IsValid() || !f.batches.Enabled() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if f.batches.enqueueTx(q, pkt, addr) {
|
|
||||||
f.observeUDPQueueLen(q)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) {
|
|
||||||
if pkt == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writer := f.writerForIndex(q)
|
|
||||||
if writer == nil {
|
|
||||||
f.l.WithField("udpAddr", addr).
|
|
||||||
WithField("writer", q).
|
|
||||||
Error("Failed to write outgoing packet: no writer available")
|
|
||||||
pkt.Release()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil {
|
|
||||||
hostinfo.logger(f.l).
|
|
||||||
WithError(err).
|
|
||||||
WithField("udpAddr", addr).
|
|
||||||
Error("Failed to write outgoing packet")
|
|
||||||
}
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) {
|
|
||||||
if pkt == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writer := f.readers[q]
|
|
||||||
if writer == nil {
|
|
||||||
pkt.Release()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if bw, ok := writer.(interface {
|
|
||||||
WriteBatch([]*overlay.Packet) (int, error)
|
|
||||||
}); ok {
|
|
||||||
if _, err := bw.WriteBatch([]*overlay.Packet{pkt}); err != nil {
|
|
||||||
f.l.WithError(err).WithField("queue", q).Warn("Failed to write tun packet via batch writer")
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil {
|
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
|
||||||
}
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) clonePacketWithHeadroom(pkt *overlay.Packet, required int) *overlay.Packet {
|
|
||||||
if pkt == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
payload := pkt.Payload()[:pkt.Len]
|
|
||||||
if len(payload) == 0 && required <= 0 {
|
|
||||||
return pkt
|
|
||||||
}
|
|
||||||
|
|
||||||
pool := f.batches.Pool()
|
|
||||||
if pool != nil {
|
|
||||||
if clone := pool.Get(); clone != nil {
|
|
||||||
if len(clone.Payload()) >= len(payload) {
|
|
||||||
clone.Len = copy(clone.Payload(), payload)
|
|
||||||
pkt.Release()
|
|
||||||
return clone
|
|
||||||
}
|
|
||||||
clone.Release()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if required < 0 {
|
|
||||||
required = 0
|
|
||||||
}
|
|
||||||
buf := make([]byte, required+len(payload))
|
|
||||||
n := copy(buf[required:], payload)
|
|
||||||
pkt.Release()
|
|
||||||
return &overlay.Packet{
|
|
||||||
Buf: buf,
|
|
||||||
Offset: required,
|
|
||||||
Len: n,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) observeUDPQueueLen(i int) {
|
|
||||||
if f.batchUDPQueueGauge == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) observeTunQueueLen(i int) {
|
|
||||||
if f.batchTunQueueGauge == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) currentBatchFlushInterval() time.Duration {
|
|
||||||
if v := f.batchFlushInterval.Load(); v > 0 {
|
|
||||||
return time.Duration(v)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) ensurePacketHeadroom(pkt **overlay.Packet, required int, queue int, reason string) bool {
|
|
||||||
p := *pkt
|
|
||||||
if p == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if required <= 0 || p.Offset >= required {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
clone := f.clonePacketWithHeadroom(p, required)
|
|
||||||
if clone == nil {
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"queue": queue,
|
|
||||||
"reason": reason,
|
|
||||||
}).Warn("dropping packet lacking tun headroom")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*pkt = clone
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isVirtioHeadroomError(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
msg := err.Error()
|
|
||||||
return strings.Contains(msg, "headroom") || strings.Contains(msg, "virtio")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) effectiveGSOMaxSegments() int {
|
|
||||||
max := f.gsoMaxSegments
|
|
||||||
if max <= 0 {
|
|
||||||
max = defaultGSOMaxSegments
|
|
||||||
}
|
|
||||||
if max > maxKernelGSOSegments {
|
|
||||||
max = maxKernelGSOSegments
|
|
||||||
}
|
|
||||||
if !f.enableGSO {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return max
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpOffloadConfigurator interface {
|
|
||||||
ConfigureOffload(enableGSO, enableGRO bool, maxSegments int)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) {
|
|
||||||
if maxSegments <= 0 {
|
|
||||||
maxSegments = defaultGSOMaxSegments
|
|
||||||
}
|
|
||||||
if maxSegments > maxKernelGSOSegments {
|
|
||||||
maxSegments = maxKernelGSOSegments
|
|
||||||
}
|
|
||||||
f.enableGSO = enableGSO
|
|
||||||
f.enableGRO = enableGRO
|
|
||||||
f.gsoMaxSegments = maxSegments
|
|
||||||
for _, writer := range f.writers {
|
|
||||||
if cfg, ok := writer.(udpOffloadConfigurator); ok {
|
|
||||||
cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1062,42 +448,6 @@ func (f *Interface) reloadMisc(c *config.C) {
|
|||||||
f.reQueryWait.Store(int64(n))
|
f.reQueryWait.Store(int64(n))
|
||||||
f.l.Info("timers.requery_wait_duration has changed")
|
f.l.Info("timers.requery_wait_duration has changed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.HasChanged("listen.gso_flush_timeout") {
|
|
||||||
d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
|
|
||||||
if d < 0 {
|
|
||||||
d = 0
|
|
||||||
}
|
|
||||||
f.batchFlushInterval.Store(int64(d))
|
|
||||||
f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed")
|
|
||||||
} else if c.HasChanged("batch.flush_interval") {
|
|
||||||
d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval)
|
|
||||||
if d < 0 {
|
|
||||||
d = 0
|
|
||||||
}
|
|
||||||
f.batchFlushInterval.Store(int64(d))
|
|
||||||
f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.HasChanged("batch.queue_depth") {
|
|
||||||
n := c.GetInt("batch.queue_depth", f.batchQueueDepth)
|
|
||||||
if n != f.batchQueueDepth {
|
|
||||||
f.batchQueueDepth = n
|
|
||||||
f.l.Warn("batch.queue_depth changes require a restart to take effect")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") {
|
|
||||||
enableGSO := c.GetBool("listen.enable_gso", f.enableGSO)
|
|
||||||
enableGRO := c.GetBool("listen.enable_gro", f.enableGRO)
|
|
||||||
maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments)
|
|
||||||
f.applyOffloadConfig(enableGSO, enableGRO, maxSeg)
|
|
||||||
f.l.WithFields(logrus.Fields{
|
|
||||||
"enableGSO": enableGSO,
|
|
||||||
"enableGRO": enableGRO,
|
|
||||||
"gsoMaxSegments": maxSeg,
|
|
||||||
}).Info("listen GSO/GRO configuration updated")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||||
|
|||||||
@@ -1337,12 +1337,19 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
||||||
for _, a := range n.Details.V4AddrPorts {
|
for _, a := range n.Details.V4AddrPorts {
|
||||||
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
b := protoV4AddrPortToNetAddrPort(a)
|
||||||
|
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||||
|
punch(b, detailsVpnAddr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V6AddrPorts {
|
for _, a := range n.Details.V6AddrPorts {
|
||||||
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
b := protoV6AddrPortToNetAddrPort(a)
|
||||||
|
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||||
|
punch(b, detailsVpnAddr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||||
|
|||||||
40
main.go
40
main.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -76,7 +75,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
||||||
|
sshStart = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,20 +144,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
// set up our UDP listener
|
// set up our UDP listener
|
||||||
udpConns := make([]udp.Conn, routines)
|
udpConns := make([]udp.Conn, routines)
|
||||||
port := c.GetInt("listen.port", 0)
|
port := c.GetInt("listen.port", 0)
|
||||||
enableGSO := c.GetBool("listen.enable_gso", true)
|
|
||||||
enableGRO := c.GetBool("listen.enable_gro", true)
|
|
||||||
gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
|
||||||
if gsoMaxSegments <= 0 {
|
|
||||||
gsoMaxSegments = defaultGSOMaxSegments
|
|
||||||
}
|
|
||||||
if gsoMaxSegments > maxKernelGSOSegments {
|
|
||||||
gsoMaxSegments = maxKernelGSOSegments
|
|
||||||
}
|
|
||||||
gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
|
|
||||||
if gsoFlushTimeout < 0 {
|
|
||||||
gsoFlushTimeout = 0
|
|
||||||
}
|
|
||||||
batchQueueDepth := c.GetInt("batch.queue_depth", 0)
|
|
||||||
|
|
||||||
if !configTest {
|
if !configTest {
|
||||||
rawListenHost := c.GetString("listen.host", "0.0.0.0")
|
rawListenHost := c.GetString("listen.host", "0.0.0.0")
|
||||||
@@ -177,27 +163,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
listenHost = ips[0].Unmap()
|
listenHost = ips[0].Unmap()
|
||||||
}
|
}
|
||||||
|
|
||||||
useWGDefault := runtime.GOOS == "linux"
|
|
||||||
useWG := c.GetBool("listen.use_wireguard_stack", useWGDefault)
|
|
||||||
var mkListener func(*logrus.Logger, netip.Addr, int, bool, int) (udp.Conn, error)
|
|
||||||
if useWG {
|
|
||||||
mkListener = udp.NewWireguardListener
|
|
||||||
} else {
|
|
||||||
mkListener = udp.NewListener
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||||
udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
}
|
}
|
||||||
udpServer.ReloadConfig(c)
|
udpServer.ReloadConfig(c)
|
||||||
if cfg, ok := udpServer.(interface {
|
|
||||||
ConfigureOffload(bool, bool, int)
|
|
||||||
}); ok {
|
|
||||||
cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
|
|
||||||
}
|
|
||||||
udpConns[i] = udpServer
|
udpConns[i] = udpServer
|
||||||
|
|
||||||
// If port is dynamic, discover it before the next pass through the for loop
|
// If port is dynamic, discover it before the next pass through the for loop
|
||||||
@@ -265,17 +237,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
||||||
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
||||||
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
||||||
EnableGSO: enableGSO,
|
|
||||||
EnableGRO: enableGRO,
|
|
||||||
GSOMaxSegments: gsoMaxSegments,
|
|
||||||
routines: routines,
|
routines: routines,
|
||||||
MessageMetrics: messageMetrics,
|
MessageMetrics: messageMetrics,
|
||||||
version: buildVersion,
|
version: buildVersion,
|
||||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
BatchFlushInterval: gsoFlushTimeout,
|
|
||||||
BatchQueueDepth: batchQueueDepth,
|
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,7 +254,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifce.writers = udpConns
|
ifce.writers = udpConns
|
||||||
ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
|
|
||||||
lightHouse.ifce = ifce
|
lightHouse.ifce = ifce
|
||||||
|
|
||||||
ifce.RegisterConfigChangeCallbacks(c)
|
ifce.RegisterConfigChangeCallbacks(c)
|
||||||
|
|||||||
169
outside.go
169
outside.go
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,7 +19,7 @@ const (
|
|||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
@@ -62,7 +61,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
switch h.Subtype {
|
switch h.Subtype {
|
||||||
case header.MessageNone:
|
case header.MessageNone:
|
||||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, ip, h.RemoteIndex) {
|
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case header.MessageRelay:
|
case header.MessageRelay:
|
||||||
@@ -96,8 +95,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
switch relay.Type {
|
switch relay.Type {
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
// If I am the target of this relay, process the unwrapped packet
|
// If I am the target of this relay, process the unwrapped packet
|
||||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
|
||||||
return
|
return
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
@@ -139,7 +137,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
@@ -161,7 +159,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
// This testRequest might be from TryPromoteBest, so we should roam
|
||||||
// to the new IP address before responding
|
// to the new IP address before responding
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
|
f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
@@ -204,7 +202,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
@@ -466,45 +464,25 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
|
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||||
var (
|
var err error
|
||||||
err error
|
|
||||||
pkt *overlay.Packet
|
|
||||||
)
|
|
||||||
|
|
||||||
if f.batches.tunQueue(q) != nil {
|
|
||||||
pkt = f.batches.newPacket()
|
|
||||||
if pkt != nil {
|
|
||||||
out = pkt.Payload()[:0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if pkt != nil {
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
if addr.IsValid() {
|
|
||||||
f.maybeSendRecvError(addr, recvIndex)
|
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(out, true, fwPacket)
|
packetData := out[virtioNetHdrLen:]
|
||||||
|
|
||||||
|
err = newPacket(packetData, true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if pkt != nil {
|
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
|
||||||
Warnf("Error while validating inbound packet")
|
Warnf("Error while validating inbound packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||||
if pkt != nil {
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping out of window packet")
|
Debugln("dropping out of window packet")
|
||||||
return false
|
return false
|
||||||
@@ -512,12 +490,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if pkt != nil {
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
WithField("reason", dropReason).
|
WithField("reason", dropReason).
|
||||||
@@ -527,17 +502,8 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
if pkt != nil {
|
_, err = f.readers[q].Write(out)
|
||||||
pkt.Len = len(out)
|
if err != nil {
|
||||||
if f.batches.enqueueTun(q, pkt) {
|
|
||||||
f.observeTunQueueLen(q)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
f.writePacketToTun(q, pkt)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = f.readers[q].Write(out); err != nil {
|
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
@@ -583,3 +549,108 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
|||||||
// We also delete it from pending hostmap to allow for fast reconnect.
|
// We also delete it from pending hostmap to allow for fast reconnect.
|
||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
|
||||||
|
// and writes all successfully decrypted packets to TUN in a single operation
|
||||||
|
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
|
||||||
|
// Pre-allocate slice for accumulating successful decryptions
|
||||||
|
tunPackets := make([][]byte, 0, count)
|
||||||
|
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
payload := payloads[i]
|
||||||
|
addr := addrs[i]
|
||||||
|
out := outs[i]
|
||||||
|
|
||||||
|
// Parse header
|
||||||
|
err := h.Parse(payload)
|
||||||
|
if err != nil {
|
||||||
|
if len(payload) > 1 {
|
||||||
|
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.IsValid() {
|
||||||
|
if f.myVpnNetworksTable.Contains(addr.Addr()) {
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var hostinfo *HostInfo
|
||||||
|
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
||||||
|
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
||||||
|
} else {
|
||||||
|
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ci *ConnectionState
|
||||||
|
if hostinfo != nil {
|
||||||
|
ci = hostinfo.ConnectionState
|
||||||
|
}
|
||||||
|
|
||||||
|
switch h.Type {
|
||||||
|
case header.Message:
|
||||||
|
if !f.handleEncrypted(ci, addr, h) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch h.Subtype {
|
||||||
|
case header.MessageNone:
|
||||||
|
// Decrypt packet
|
||||||
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
packetData := out[virtioNetHdrLen:]
|
||||||
|
|
||||||
|
err = newPacket(packetData, true, fwPacket)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
|
||||||
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
|
if dropReason != nil {
|
||||||
|
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
f.connectionManager.In(hostinfo)
|
||||||
|
// Add to batch for TUN write
|
||||||
|
tunPackets = append(tunPackets, out)
|
||||||
|
|
||||||
|
case header.MessageRelay:
|
||||||
|
// Skip relay packets in batch mode for now (less common path)
|
||||||
|
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
|
|
||||||
|
default:
|
||||||
|
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Handle non-Message types using single-packet path
|
||||||
|
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tunPackets) > 0 {
|
||||||
|
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
|
||||||
|
}
|
||||||
|
f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,97 +3,29 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device interface {
|
// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations
|
||||||
|
type BatchReadWriter interface {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
|
|
||||||
|
// BatchRead reads multiple packets at once
|
||||||
|
BatchRead(bufs [][]byte, sizes []int) (int, error)
|
||||||
|
|
||||||
|
// WriteBatch writes multiple packets at once
|
||||||
|
WriteBatch(bufs [][]byte, offset int) (int, error)
|
||||||
|
|
||||||
|
// BatchSize returns the optimal batch size for this device
|
||||||
|
BatchSize() int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Device interface {
|
||||||
|
BatchReadWriter
|
||||||
Activate() error
|
Activate() error
|
||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (BatchReadWriter, error)
|
||||||
}
|
|
||||||
|
|
||||||
// Packet represents a single packet buffer with optional headroom to carry
|
|
||||||
// metadata (for example virtio-net headers).
|
|
||||||
type Packet struct {
|
|
||||||
Buf []byte
|
|
||||||
Offset int
|
|
||||||
Len int
|
|
||||||
release func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) Payload() []byte {
|
|
||||||
return p.Buf[p.Offset : p.Offset+p.Len]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) Reset() {
|
|
||||||
p.Len = 0
|
|
||||||
p.Offset = 0
|
|
||||||
p.release = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) Release() {
|
|
||||||
if p.release != nil {
|
|
||||||
p.release()
|
|
||||||
p.release = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) Capacity() int {
|
|
||||||
return len(p.Buf) - p.Offset
|
|
||||||
}
|
|
||||||
|
|
||||||
// PacketPool manages reusable buffers with headroom.
|
|
||||||
type PacketPool struct {
|
|
||||||
headroom int
|
|
||||||
blksz int
|
|
||||||
pool sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPacketPool(headroom, payload int) *PacketPool {
|
|
||||||
p := &PacketPool{headroom: headroom, blksz: headroom + payload}
|
|
||||||
p.pool.New = func() any {
|
|
||||||
buf := make([]byte, p.blksz)
|
|
||||||
return &Packet{Buf: buf, Offset: headroom}
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PacketPool) Get() *Packet {
|
|
||||||
pkt := p.pool.Get().(*Packet)
|
|
||||||
pkt.Offset = p.headroom
|
|
||||||
pkt.Len = 0
|
|
||||||
pkt.release = func() { p.put(pkt) }
|
|
||||||
return pkt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PacketPool) put(pkt *Packet) {
|
|
||||||
pkt.Reset()
|
|
||||||
p.pool.Put(pkt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchReader allows reading multiple packets into a shared pool with
|
|
||||||
// preallocated headroom (e.g. virtio-net headers).
|
|
||||||
type BatchReader interface {
|
|
||||||
ReadIntoBatch(pool *PacketPool) ([]*Packet, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchWriter writes a slice of packets that carry their own metadata.
|
|
||||||
type BatchWriter interface {
|
|
||||||
WriteBatch(packets []*Packet) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchCapableDevice describes a device that can efficiently read and write
|
|
||||||
// batches of packets with virtio headroom.
|
|
||||||
type BatchCapableDevice interface {
|
|
||||||
Device
|
|
||||||
BatchReader
|
|
||||||
BatchWriter
|
|
||||||
BatchHeadroom() int
|
|
||||||
BatchPayloadCap() int
|
|
||||||
BatchSize() int
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMTU = 1300
|
const DefaultMTU = 1300
|
||||||
|
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|||||||
@@ -95,6 +95,29 @@ func (t *tun) Name() string {
|
|||||||
return "android"
|
return "android"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|||||||
@@ -549,6 +549,32 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchRead reads a single packet (batch size 1 for non-Linux platforms)
|
||||||
|
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteBatch writes packets individually (no batching for non-Linux platforms)
|
||||||
|
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns 1 for non-Linux platforms (no batching)
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|||||||
@@ -105,10 +105,36 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchRead reads a single packet (batch size 1 for disabled tun)
|
||||||
|
func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteBatch writes packets individually (no batching for disabled tun)
|
||||||
|
func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns 1 for disabled tun (no batching)
|
||||||
|
func (t *disabledTun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (t *disabledTun) Close() error {
|
func (t *disabledTun) Close() error {
|
||||||
if t.read != nil {
|
if t.read != nil {
|
||||||
close(t.read)
|
close(t.read)
|
||||||
|
|||||||
@@ -450,10 +450,36 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchRead reads a single packet (batch size 1 for FreeBSD)
|
||||||
|
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteBatch writes packets individually (no batching for FreeBSD)
|
||||||
|
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns 1 for FreeBSD (no batching)
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
|||||||
@@ -151,6 +151,29 @@ func (t *tun) Name() string {
|
|||||||
return "iOS"
|
return "iOS"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -20,13 +18,14 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
|
wgDevice wgtun.Device
|
||||||
fd int
|
fd int
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
@@ -35,7 +34,6 @@ type tun struct {
|
|||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
deviceIndex int
|
deviceIndex int
|
||||||
ioctlFd uintptr
|
ioctlFd uintptr
|
||||||
wgDevice wgtun.Device
|
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
@@ -68,107 +66,169 @@ type ifreqQLEN struct {
|
|||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
// This allows multiqueue readers to use the same wireguard Device batching as the main device
|
||||||
|
type wgDeviceWrapper struct {
|
||||||
useWGDefault := runtime.GOOS == "linux"
|
dev wgtun.Device
|
||||||
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
|
buf []byte // Reusable buffer for single packet reads
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Device = "tun0"
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
// Use wireguard Device's batch API for single packet
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
sizes := make([]int, 1)
|
||||||
|
n, err := w.dev.Read(bufs, sizes, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
return 0, err
|
||||||
if os.IsNotExist(err) {
|
|
||||||
err = os.MkdirAll("/dev/net", 0755)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
|
||||||
}
|
|
||||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if n == 0 {
|
||||||
var req ifReq
|
return 0, io.EOF
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
|
||||||
if multiqueue {
|
|
||||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
|
||||||
}
|
}
|
||||||
copy(req.Name[:], c.GetString("tun.dev", ""))
|
return sizes[0], nil
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
}
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
|
||||||
useWGDefault := runtime.GOOS == "linux"
|
// Buffer b should have virtio header space (10 bytes) at the beginning
|
||||||
useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", useWGDefault))
|
// The decrypted packet data starts at offset 10
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
// Pass all buffers to WireGuard's batch write
|
||||||
|
return w.dev.Write(bufs, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wgDeviceWrapper) Close() error {
|
||||||
|
return w.dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRead implements batching for multiqueue readers
|
||||||
|
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
// The zero here is offset.
|
||||||
|
return w.dev.Read(bufs, sizes, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the optimal batch size
|
||||||
|
func (w *wgDeviceWrapper) BatchSize() int {
|
||||||
|
return w.dev.BatchSize()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
|
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file := wgDev.File()
|
||||||
|
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||||
|
if err != nil {
|
||||||
|
_ = wgDev.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.wgDevice = wgDev
|
||||||
t.Device = name
|
t.Device = name
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
var (
|
// Check if /dev/net/tun exists, create if needed (for docker containers)
|
||||||
rw io.ReadWriteCloser = file
|
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
|
||||||
fd = int(file.Fd())
|
if err := os.MkdirAll("/dev/net", 0755); err != nil {
|
||||||
wgDev wgtun.Device
|
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||||
)
|
}
|
||||||
|
if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
|
||||||
if useWireguard {
|
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||||
dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err)
|
|
||||||
}
|
}
|
||||||
wgDev = dev
|
|
||||||
rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU))
|
|
||||||
fd = int(dev.File().Fd())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
devName := c.GetString("tun.dev", "")
|
||||||
|
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
|
|
||||||
|
// Create TUN device manually to support multiqueue
|
||||||
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ifReq
|
||||||
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
|
||||||
|
if multiqueue {
|
||||||
|
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||||
|
}
|
||||||
|
copy(req.Name[:], devName)
|
||||||
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set nonblocking
|
||||||
|
if err = unix.SetNonblock(fd, true); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable TCP and UDP offload (TSO/GRO) for performance
|
||||||
|
// This allows the kernel to handle segmentation/coalescing
|
||||||
|
const (
|
||||||
|
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||||
|
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
|
||||||
|
)
|
||||||
|
offloads := tunTCPOffloads | tunUDPOffloads
|
||||||
|
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
|
||||||
|
// Log warning but don't fail - offload is optional
|
||||||
|
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
|
|
||||||
|
// Create wireguard device from file descriptor
|
||||||
|
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
|
||||||
|
if err != nil {
|
||||||
|
file.Close()
|
||||||
|
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := wgDev.Name()
|
||||||
|
if err != nil {
|
||||||
|
_ = wgDev.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// file is now owned by wgDev, get a new reference
|
||||||
|
file = wgDev.File()
|
||||||
|
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||||
|
if err != nil {
|
||||||
|
_ = wgDev.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.wgDevice = wgDev
|
||||||
|
t.Device = name
|
||||||
|
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: rw,
|
ReadWriteCloser: file,
|
||||||
fd: fd,
|
fd: int(file.Fd()),
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
if wgDev != nil {
|
|
||||||
t.wgDevice = wgDev
|
|
||||||
}
|
|
||||||
if wgDev != nil {
|
|
||||||
// replace ioctl fd with device file descriptor to keep route management working
|
|
||||||
file = wgDev.File()
|
|
||||||
t.fd = int(file.Fd())
|
|
||||||
t.ioctlFd = file.Fd()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.ioctlFd == 0 {
|
|
||||||
t.ioctlFd = file.Fd()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -252,22 +312,44 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var req ifReq
|
var req ifReq
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
// MUST match the flags used in newTun - includes IFF_VNET_HDR
|
||||||
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
|
||||||
copy(req.Name[:], t.Device)
|
copy(req.Name[:], t.Device)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set nonblocking mode - CRITICAL for proper netpoller integration
|
||||||
|
if err = unix.SetNonblock(fd, true); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get MTU from main device
|
||||||
|
mtu := t.MaxMTU
|
||||||
|
if mtu == 0 {
|
||||||
|
mtu = DefaultMTU
|
||||||
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
|
|
||||||
return file, nil
|
// Create wireguard Device from the file descriptor (just like the main device)
|
||||||
|
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
|
||||||
|
if err != nil {
|
||||||
|
file.Close()
|
||||||
|
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a wrapper that uses the wireguard Device for all I/O
|
||||||
|
return &wgDeviceWrapper{dev: wgDev}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
@@ -275,7 +357,68 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read(b []byte) (int, error) {
|
||||||
|
if t.wgDevice != nil {
|
||||||
|
// Use wireguard device which handles virtio headers internally
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
sizes := make([]int, 1)
|
||||||
|
n, err := t.wgDevice.Read(bufs, sizes, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return sizes[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: direct read from file (shouldn't happen in normal operation)
|
||||||
|
return t.ReadWriteCloser.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRead reads multiple packets at once for improved performance
|
||||||
|
// bufs: slice of buffers to read into
|
||||||
|
// sizes: slice that will be filled with packet sizes
|
||||||
|
// Returns number of packets read
|
||||||
|
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
if t.wgDevice != nil {
|
||||||
|
return t.wgDevice.Read(bufs, sizes, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: single packet read
|
||||||
|
n, err := t.ReadWriteCloser.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the optimal number of packets to read/write in a batch
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
if t.wgDevice != nil {
|
||||||
|
return t.wgDevice.BatchSize()
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
func (t *tun) Write(b []byte) (int, error) {
|
||||||
|
if t.wgDevice != nil {
|
||||||
|
// Buffer b should have virtio header space (10 bytes) at the beginning
|
||||||
|
// The decrypted packet data starts at offset 10
|
||||||
|
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: direct write (shouldn't happen in normal operation)
|
||||||
var nn int
|
var nn int
|
||||||
maximum := len(b)
|
maximum := len(b)
|
||||||
|
|
||||||
@@ -298,6 +441,22 @@ func (t *tun) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteBatch writes multiple packets to the TUN device in a single syscall
|
||||||
|
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
if t.wgDevice != nil {
|
||||||
|
return t.wgDevice.Write(bufs, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: write individually (shouldn't happen in normal operation)
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
for i, c := range t.Device {
|
for i, c := range t.Device {
|
||||||
o[i] = byte(c)
|
o[i] = byte(c)
|
||||||
@@ -710,16 +869,12 @@ func (t *tun) Close() error {
|
|||||||
close(t.routeChan)
|
close(t.routeChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.ReadWriteCloser != nil {
|
|
||||||
_ = t.ReadWriteCloser.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.wgDevice != nil {
|
if t.wgDevice != nil {
|
||||||
_ = t.wgDevice.Close()
|
_ = t.wgDevice.Close()
|
||||||
if t.ioctlFd > 0 {
|
}
|
||||||
// underlying fd already closed by the device
|
|
||||||
t.ioctlFd = 0
|
if t.ReadWriteCloser != nil {
|
||||||
}
|
_ = t.ReadWriteCloser.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.ioctlFd > 0 {
|
if t.ioctlFd > 0 {
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
//go:build linux && !android && !e2e_testing
|
|
||||||
|
|
||||||
package overlay
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
func (t *tun) batchIO() (*wireguardTunIO, bool) {
|
|
||||||
io, ok := t.ReadWriteCloser.(*wireguardTunIO)
|
|
||||||
return io, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
|
||||||
io, ok := t.batchIO()
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("wireguard batch I/O not enabled")
|
|
||||||
}
|
|
||||||
return io.ReadIntoBatch(pool)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) WriteBatch(packets []*Packet) (int, error) {
|
|
||||||
io, ok := t.batchIO()
|
|
||||||
if ok {
|
|
||||||
return io.WriteBatch(packets)
|
|
||||||
}
|
|
||||||
for _, pkt := range packets {
|
|
||||||
if pkt == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
return len(packets), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) BatchHeadroom() int {
|
|
||||||
if io, ok := t.batchIO(); ok {
|
|
||||||
return io.BatchHeadroom()
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) BatchPayloadCap() int {
|
|
||||||
if io, ok := t.batchIO(); ok {
|
|
||||||
return io.BatchPayloadCap()
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
if io, ok := t.batchIO(); ok {
|
|
||||||
return io.BatchSize()
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
@@ -390,10 +390,33 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
|
|||||||
@@ -310,10 +310,33 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
|
|||||||
@@ -132,6 +132,29 @@ func (t *TestTun) Read(b []byte) (int, error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TestTun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -234,10 +233,36 @@ func (t *winTun) Write(b []byte) (int, error) {
|
|||||||
return t.tun.Write(b, 0)
|
return t.tun.Write(b, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchRead reads a single packet (batch size 1 for Windows)
|
||||||
|
func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := t.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteBatch writes packets individually (no batching for Windows)
|
||||||
|
func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := t.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns 1 for Windows (no batching)
|
||||||
|
func (t *winTun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (t *winTun) Close() error {
|
func (t *winTun) Close() error {
|
||||||
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
||||||
// so to be certain, just remove everything before destroying.
|
// so to be certain, just remove everything before destroying.
|
||||||
|
|||||||
@@ -46,10 +46,36 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchRead reads a single packet (batch size 1 for UserDevice)
|
||||||
|
func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||||
|
n, err := d.Read(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteBatch writes packets individually (no batching for UserDevice)
|
||||||
|
func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for i, buf := range bufs {
|
||||||
|
_, err := d.Write(buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(bufs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns 1 for UserDevice (no batching)
|
||||||
|
func (d *UserDevice) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
||||||
return d.inboundReader, d.outboundWriter
|
return d.inboundReader, d.outboundWriter
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,220 +0,0 @@
|
|||||||
//go:build linux && !android && !e2e_testing
|
|
||||||
|
|
||||||
package overlay
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
type wireguardTunIO struct {
|
|
||||||
dev wgtun.Device
|
|
||||||
mtu int
|
|
||||||
batchSize int
|
|
||||||
|
|
||||||
readMu sync.Mutex
|
|
||||||
readBuffers [][]byte
|
|
||||||
readLens []int
|
|
||||||
legacyBuf []byte
|
|
||||||
|
|
||||||
writeMu sync.Mutex
|
|
||||||
writeBuf []byte
|
|
||||||
writeWrap [][]byte
|
|
||||||
writeBuffers [][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
|
||||||
batch := dev.BatchSize()
|
|
||||||
if batch <= 0 {
|
|
||||||
batch = 1
|
|
||||||
}
|
|
||||||
if mtu <= 0 {
|
|
||||||
mtu = DefaultMTU
|
|
||||||
}
|
|
||||||
return &wireguardTunIO{
|
|
||||||
dev: dev,
|
|
||||||
mtu: mtu,
|
|
||||||
batchSize: batch,
|
|
||||||
readLens: make([]int, batch),
|
|
||||||
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
|
||||||
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
|
||||||
writeWrap: make([][]byte, 1),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) Read(p []byte) (int, error) {
|
|
||||||
w.readMu.Lock()
|
|
||||||
defer w.readMu.Unlock()
|
|
||||||
|
|
||||||
bufs := w.readBuffers
|
|
||||||
if len(bufs) == 0 {
|
|
||||||
bufs = [][]byte{w.legacyBuf}
|
|
||||||
w.readBuffers = bufs
|
|
||||||
}
|
|
||||||
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
length := w.readLens[0]
|
|
||||||
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
|
|
||||||
return length, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
|
||||||
if len(p) > w.mtu {
|
|
||||||
return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
|
|
||||||
}
|
|
||||||
w.writeMu.Lock()
|
|
||||||
defer w.writeMu.Unlock()
|
|
||||||
buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
|
|
||||||
for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
|
|
||||||
buf[i] = 0
|
|
||||||
}
|
|
||||||
copy(buf[wgtun.VirtioNetHdrLen:], p)
|
|
||||||
w.writeWrap[0] = buf
|
|
||||||
n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
|
||||||
if pool == nil {
|
|
||||||
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
w.readMu.Lock()
|
|
||||||
defer w.readMu.Unlock()
|
|
||||||
|
|
||||||
if len(w.readBuffers) < w.batchSize {
|
|
||||||
w.readBuffers = make([][]byte, w.batchSize)
|
|
||||||
}
|
|
||||||
if len(w.readLens) < w.batchSize {
|
|
||||||
w.readLens = make([]int, w.batchSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
packets := make([]*Packet, w.batchSize)
|
|
||||||
requiredHeadroom := w.BatchHeadroom()
|
|
||||||
requiredPayload := w.BatchPayloadCap()
|
|
||||||
headroom := 0
|
|
||||||
for i := 0; i < w.batchSize; i++ {
|
|
||||||
pkt := pool.Get()
|
|
||||||
if pkt == nil {
|
|
||||||
releasePackets(packets[:i])
|
|
||||||
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
|
|
||||||
}
|
|
||||||
if pkt.Capacity() < requiredPayload {
|
|
||||||
pkt.Release()
|
|
||||||
releasePackets(packets[:i])
|
|
||||||
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
|
|
||||||
}
|
|
||||||
if i == 0 {
|
|
||||||
headroom = pkt.Offset
|
|
||||||
if headroom < requiredHeadroom {
|
|
||||||
pkt.Release()
|
|
||||||
releasePackets(packets[:i])
|
|
||||||
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
|
|
||||||
}
|
|
||||||
} else if pkt.Offset != headroom {
|
|
||||||
pkt.Release()
|
|
||||||
releasePackets(packets[:i])
|
|
||||||
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
|
|
||||||
}
|
|
||||||
packets[i] = pkt
|
|
||||||
w.readBuffers[i] = pkt.Buf
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
|
|
||||||
if err != nil {
|
|
||||||
releasePackets(packets)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
releasePackets(packets)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
packets[i].Len = w.readLens[i]
|
|
||||||
}
|
|
||||||
for i := n; i < w.batchSize; i++ {
|
|
||||||
packets[i].Release()
|
|
||||||
packets[i] = nil
|
|
||||||
}
|
|
||||||
return packets[:n], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
|
|
||||||
if len(packets) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
requiredHeadroom := w.BatchHeadroom()
|
|
||||||
offset := packets[0].Offset
|
|
||||||
if offset < requiredHeadroom {
|
|
||||||
releasePackets(packets)
|
|
||||||
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
|
|
||||||
}
|
|
||||||
for _, pkt := range packets {
|
|
||||||
if pkt == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if pkt.Offset != offset {
|
|
||||||
releasePackets(packets)
|
|
||||||
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
|
|
||||||
}
|
|
||||||
limit := pkt.Offset + pkt.Len
|
|
||||||
if limit > len(pkt.Buf) {
|
|
||||||
releasePackets(packets)
|
|
||||||
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.writeMu.Lock()
|
|
||||||
defer w.writeMu.Unlock()
|
|
||||||
|
|
||||||
if len(w.writeBuffers) < len(packets) {
|
|
||||||
w.writeBuffers = make([][]byte, len(packets))
|
|
||||||
}
|
|
||||||
for i, pkt := range packets {
|
|
||||||
if pkt == nil {
|
|
||||||
w.writeBuffers[i] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
limit := pkt.Offset + pkt.Len
|
|
||||||
w.writeBuffers[i] = pkt.Buf[:limit]
|
|
||||||
}
|
|
||||||
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
releasePackets(packets)
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) BatchHeadroom() int {
|
|
||||||
return wgtun.VirtioNetHdrLen
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) BatchPayloadCap() int {
|
|
||||||
return w.mtu
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) BatchSize() int {
|
|
||||||
return w.batchSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireguardTunIO) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func releasePackets(pkts []*Packet) {
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if pkt != nil {
|
|
||||||
pkt.Release()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
1
stats.go
1
stats.go
@@ -6,6 +6,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|||||||
30
udp/conn.go
30
udp/conn.go
@@ -13,27 +13,24 @@ type EncReader func(
|
|||||||
payload []byte,
|
payload []byte,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type EncBatchReader func(
|
||||||
|
addrs []netip.AddrPort,
|
||||||
|
payloads [][]byte,
|
||||||
|
count int,
|
||||||
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader)
|
||||||
|
ListenOutBatch(r EncBatchReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
|
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
|
BatchSize() int
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Datagram represents a UDP payload destined to a specific address.
|
|
||||||
type Datagram struct {
|
|
||||||
Payload []byte
|
|
||||||
Addr netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchConn can send multiple datagrams in one syscall.
|
|
||||||
type BatchConn interface {
|
|
||||||
Conn
|
|
||||||
WriteBatch(pkts []Datagram) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type NoopConn struct{}
|
type NoopConn struct{}
|
||||||
|
|
||||||
func (NoopConn) Rebind() error {
|
func (NoopConn) Rebind() error {
|
||||||
@@ -45,12 +42,21 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
func (NoopConn) ListenOut(_ EncReader) {
|
func (NoopConn) ListenOut(_ EncReader) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
func (NoopConn) ListenOutBatch(_ EncBatchReader) {
|
||||||
|
return
|
||||||
|
}
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
func (NoopConn) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
func (NoopConn) Close() error {
|
func (NoopConn) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
|
||||||
|
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
for i := range packets {
|
||||||
|
err := u.WriteTo(packets[i], addrs[i])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(packets), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
a := u.UDPConn.LocalAddr()
|
a := u.UDPConn.LocalAddr()
|
||||||
|
|
||||||
@@ -184,6 +195,34 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListenOutBatch - fallback to single-packet reads for Darwin
|
||||||
|
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
|
||||||
|
buffer := make([]byte, MTU)
|
||||||
|
addrs := make([]netip.AddrPort, 1)
|
||||||
|
payloads := make([][]byte, 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Just read one packet at a time and call batch callback with count=1
|
||||||
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
|
||||||
|
payloads[0] = buffer[:n]
|
||||||
|
r(addrs, payloads, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
var err error
|
var err error
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
|
|||||||
@@ -85,3 +85,42 @@ func (u *GenericConn) ListenOut(r EncReader) {
|
|||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListenOutBatch - fallback to single-packet reads for generic platforms
|
||||||
|
func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
|
||||||
|
buffer := make([]byte, MTU)
|
||||||
|
addrs := make([]netip.AddrPort, 1)
|
||||||
|
payloads := make([][]byte, 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Just read one packet at a time and call batch callback with count=1
|
||||||
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
|
if err != nil {
|
||||||
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
|
||||||
|
payloads[0] = buffer[:n]
|
||||||
|
r(addrs, payloads, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteMulti sends multiple packets - fallback implementation
|
||||||
|
func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
for i := range packets {
|
||||||
|
err := u.WriteTo(packets[i], addrs[i])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(packets), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *GenericConn) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *GenericConn) Rebind() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
262
udp/udp_linux.go
262
udp/udp_linux.go
@@ -22,6 +22,11 @@ type StdConn struct {
|
|||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
|
|
||||||
|
// Pre-allocated buffers for batch writes (sized for IPv6, works for both)
|
||||||
|
writeMsgs []rawMessage
|
||||||
|
writeIovecs []iovec
|
||||||
|
writeNames [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||||
@@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}
|
||||||
|
|
||||||
|
// Pre-allocate write message structures for batching (sized for IPv6, works for both)
|
||||||
|
c.writeMsgs = make([]rawMessage, batch)
|
||||||
|
c.writeIovecs = make([]iovec, batch)
|
||||||
|
c.writeNames = make([][]byte, batch)
|
||||||
|
|
||||||
|
for i := range c.writeMsgs {
|
||||||
|
// Allocate for IPv6 size (larger than IPv4, works for both)
|
||||||
|
c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
|
// Point to the iovec in the slice
|
||||||
|
c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i]
|
||||||
|
c.writeMsgs[i].Hdr.Iovlen = 1
|
||||||
|
|
||||||
|
c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0]
|
||||||
|
// Namelen will be set appropriately in writeMulti4/writeMulti6
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
@@ -127,6 +151,8 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
}
|
}
|
||||||
|
|
||||||
|
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -134,6 +160,8 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
udpBatchHist.Update(int64(n))
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
@@ -146,6 +174,46 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
|
||||||
|
var ip netip.Addr
|
||||||
|
|
||||||
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
|
read := u.ReadMulti
|
||||||
|
if u.batch == 1 {
|
||||||
|
read = u.ReadSingle
|
||||||
|
}
|
||||||
|
|
||||||
|
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
|
||||||
|
|
||||||
|
// Pre-allocate slices for batch callback
|
||||||
|
addrs := make([]netip.AddrPort, u.batch)
|
||||||
|
payloads := make([][]byte, u.batch)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := read(msgs)
|
||||||
|
if err != nil {
|
||||||
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
udpBatchHist.Update(int64(n))
|
||||||
|
|
||||||
|
// Prepare batch data
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if u.isV4 {
|
||||||
|
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||||
|
} else {
|
||||||
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
|
}
|
||||||
|
addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||||
|
payloads[i] = buffers[i][:msgs[i].Len]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call batch callback with all packets
|
||||||
|
r(addrs, payloads, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||||
for {
|
for {
|
||||||
n, _, err := unix.Syscall6(
|
n, _, err := unix.Syscall6(
|
||||||
@@ -194,6 +262,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
|||||||
return u.writeTo6(b, ip)
|
return u.writeTo6(b, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
if len(packets) != len(addrs) {
|
||||||
|
return 0, fmt.Errorf("packets and addrs length mismatch")
|
||||||
|
}
|
||||||
|
if len(packets) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if u.isV4 {
|
||||||
|
return u.writeMulti4(packets, addrs)
|
||||||
|
}
|
||||||
|
return u.writeMulti6(packets, addrs)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||||
var rsa unix.RawSockaddrInet6
|
var rsa unix.RawSockaddrInet6
|
||||||
rsa.Family = unix.AF_INET6
|
rsa.Family = unix.AF_INET6
|
||||||
@@ -248,6 +329,123 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
sent := 0
|
||||||
|
for sent < len(packets) {
|
||||||
|
// Determine batch size based on remaining packets and buffer capacity
|
||||||
|
batchSize := len(packets) - sent
|
||||||
|
if batchSize > len(u.writeMsgs) {
|
||||||
|
batchSize = len(u.writeMsgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pre-allocated buffers
|
||||||
|
msgs := u.writeMsgs[:batchSize]
|
||||||
|
iovecs := u.writeIovecs[:batchSize]
|
||||||
|
names := u.writeNames[:batchSize]
|
||||||
|
|
||||||
|
// Setup message structures for this batch
|
||||||
|
for i := 0; i < batchSize; i++ {
|
||||||
|
pktIdx := sent + i
|
||||||
|
if !addrs[pktIdx].Addr().Is4() {
|
||||||
|
return sent + i, ErrInvalidIPv6RemoteForSocket
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the packet buffer
|
||||||
|
iovecs[i].Base = &packets[pktIdx][0]
|
||||||
|
iovecs[i].Len = uint(len(packets[pktIdx]))
|
||||||
|
|
||||||
|
// Setup the destination address
|
||||||
|
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
|
||||||
|
rsa.Family = unix.AF_INET
|
||||||
|
rsa.Addr = addrs[pktIdx].Addr().As4()
|
||||||
|
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
|
||||||
|
|
||||||
|
// Set the appropriate address length for IPv4
|
||||||
|
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send this batch
|
||||||
|
nsent, _, err := unix.Syscall6(
|
||||||
|
unix.SYS_SENDMMSG,
|
||||||
|
uintptr(u.sysFd),
|
||||||
|
uintptr(unsafe.Pointer(&msgs[0])),
|
||||||
|
uintptr(batchSize),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != 0 {
|
||||||
|
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
sent += int(nsent)
|
||||||
|
if int(nsent) < batchSize {
|
||||||
|
// Couldn't send all packets in batch, return what we sent
|
||||||
|
return sent, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
sent := 0
|
||||||
|
for sent < len(packets) {
|
||||||
|
// Determine batch size based on remaining packets and buffer capacity
|
||||||
|
batchSize := len(packets) - sent
|
||||||
|
if batchSize > len(u.writeMsgs) {
|
||||||
|
batchSize = len(u.writeMsgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pre-allocated buffers
|
||||||
|
msgs := u.writeMsgs[:batchSize]
|
||||||
|
iovecs := u.writeIovecs[:batchSize]
|
||||||
|
names := u.writeNames[:batchSize]
|
||||||
|
|
||||||
|
// Setup message structures for this batch
|
||||||
|
for i := 0; i < batchSize; i++ {
|
||||||
|
pktIdx := sent + i
|
||||||
|
|
||||||
|
// Setup the packet buffer
|
||||||
|
iovecs[i].Base = &packets[pktIdx][0]
|
||||||
|
iovecs[i].Len = uint(len(packets[pktIdx]))
|
||||||
|
|
||||||
|
// Setup the destination address
|
||||||
|
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
|
||||||
|
rsa.Family = unix.AF_INET6
|
||||||
|
rsa.Addr = addrs[pktIdx].Addr().As16()
|
||||||
|
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
|
||||||
|
|
||||||
|
// Set the appropriate address length for IPv6
|
||||||
|
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send this batch
|
||||||
|
nsent, _, err := unix.Syscall6(
|
||||||
|
unix.SYS_SENDMMSG,
|
||||||
|
uintptr(u.sysFd),
|
||||||
|
uintptr(unsafe.Pointer(&msgs[0])),
|
||||||
|
uintptr(batchSize),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != 0 {
|
||||||
|
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
sent += int(nsent)
|
||||||
|
if int(nsent) < batchSize {
|
||||||
|
// Couldn't send all packets in batch, return what we sent
|
||||||
|
return sent, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sent, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||||
b := c.GetInt("listen.read_buffer", 0)
|
b := c.GetInt("listen.read_buffer", 0)
|
||||||
if b > 0 {
|
if b > 0 {
|
||||||
@@ -305,56 +503,40 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) BatchSize() int {
|
||||||
|
return u.batch
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) Close() error {
|
func (u *StdConn) Close() error {
|
||||||
return syscall.Close(u.sysFd)
|
return syscall.Close(u.sysFd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||||
if len(udpConns) == 0 {
|
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
||||||
return func() {}
|
var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
|
||||||
}
|
|
||||||
|
|
||||||
type statsProvider struct {
|
|
||||||
index int
|
|
||||||
conn *StdConn
|
|
||||||
}
|
|
||||||
|
|
||||||
providers := make([]statsProvider, 0, len(udpConns))
|
|
||||||
for i, c := range udpConns {
|
|
||||||
if sc, ok := c.(*StdConn); ok {
|
|
||||||
providers = append(providers, statsProvider{index: i, conn: sc})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(providers) == 0 {
|
|
||||||
return func() {}
|
|
||||||
}
|
|
||||||
|
|
||||||
var meminfo [unix.SK_MEMINFO_VARS]uint32
|
var meminfo [unix.SK_MEMINFO_VARS]uint32
|
||||||
if err := providers[0].conn.getMemInfo(&meminfo); err != nil {
|
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
|
||||||
return func() {}
|
udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
|
||||||
}
|
for i := range udpConns {
|
||||||
|
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
|
||||||
udpGauges := make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(providers))
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
|
||||||
for i, provider := range providers {
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
|
||||||
udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", provider.index), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", provider.index), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", provider.index), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", provider.index), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", provider.index), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", provider.index), nil),
|
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", provider.index), nil),
|
}
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", provider.index), nil),
|
|
||||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", provider.index), nil),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return func() {
|
return func() {
|
||||||
for i, provider := range providers {
|
for i, gauges := range udpGauges {
|
||||||
if err := provider.conn.getMemInfo(&meminfo); err == nil {
|
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
|
||||||
for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
|
for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
|
||||||
udpGauges[i][j].Update(int64(meminfo[j]))
|
gauges[j].Update(int64(meminfo[j]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type iovec struct {
|
type iovec struct {
|
||||||
Base *byte
|
Base *byte
|
||||||
Len uint32
|
Len uint
|
||||||
}
|
}
|
||||||
|
|
||||||
type msghdr struct {
|
type msghdr struct {
|
||||||
@@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
|
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type iovec struct {
|
type iovec struct {
|
||||||
Base *byte
|
Base *byte
|
||||||
Len uint64
|
Len uint
|
||||||
}
|
}
|
||||||
|
|
||||||
type msghdr struct {
|
type msghdr struct {
|
||||||
@@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
|
|||||||
@@ -116,6 +116,31 @@ func (u *TesterConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
|
||||||
|
addrs := make([]netip.AddrPort, 1)
|
||||||
|
payloads := make([][]byte, 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
p, ok := <-u.RxPackets
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addrs[0] = p.From
|
||||||
|
payloads[0] = p.Data
|
||||||
|
r(addrs, payloads, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
for i := range packets {
|
||||||
|
err := u.WriteTo(packets[i], addrs[i])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(packets), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *TesterConn) ReloadConfig(*config.C) {}
|
func (u *TesterConn) ReloadConfig(*config.C) {}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(_ []Conn) func() {
|
func NewUDPStatsEmitter(_ []Conn) func() {
|
||||||
@@ -127,6 +152,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
return u.Addr, nil
|
return u.Addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *TesterConn) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (u *TesterConn) Rebind() error {
|
func (u *TesterConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,225 +0,0 @@
|
|||||||
//go:build linux && !android && !e2e_testing
|
|
||||||
|
|
||||||
package udp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
|
|
||||||
type WGConn struct {
|
|
||||||
l *logrus.Logger
|
|
||||||
bind *wgconn.StdNetBind
|
|
||||||
recvers []wgconn.ReceiveFunc
|
|
||||||
batch int
|
|
||||||
reqBatch int
|
|
||||||
localIP netip.Addr
|
|
||||||
localPort uint16
|
|
||||||
enableGSO bool
|
|
||||||
enableGRO bool
|
|
||||||
gsoMaxSeg int
|
|
||||||
closed atomic.Bool
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
|
|
||||||
func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
|
||||||
bind := wgconn.NewStdNetBindForAddr(ip, multi)
|
|
||||||
recvers, actualPort, err := bind.Open(uint16(port))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if batch <= 0 {
|
|
||||||
batch = bind.BatchSize()
|
|
||||||
} else if batch > bind.BatchSize() {
|
|
||||||
batch = bind.BatchSize()
|
|
||||||
}
|
|
||||||
return &WGConn{
|
|
||||||
l: l,
|
|
||||||
bind: bind,
|
|
||||||
recvers: recvers,
|
|
||||||
batch: batch,
|
|
||||||
reqBatch: batch,
|
|
||||||
localIP: ip,
|
|
||||||
localPort: actualPort,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) Rebind() error {
|
|
||||||
// WireGuard's bind does not support rebinding in place.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
|
|
||||||
if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
|
|
||||||
// Fallback to wildcard IPv4 for display purposes.
|
|
||||||
return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
|
|
||||||
}
|
|
||||||
return netip.AddrPortFrom(c.localIP, c.localPort), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
|
|
||||||
batchSize := c.batch
|
|
||||||
packets := make([][]byte, batchSize)
|
|
||||||
for i := range packets {
|
|
||||||
packets[i] = make([]byte, MTU)
|
|
||||||
}
|
|
||||||
sizes := make([]int, batchSize)
|
|
||||||
endpoints := make([]wgconn.Endpoint, batchSize)
|
|
||||||
|
|
||||||
for {
|
|
||||||
if c.closed.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n, err := fn(packets, sizes, endpoints)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if c.l != nil {
|
|
||||||
c.l.WithError(err).Debug("wireguard UDP listener receive error")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
if sizes[i] == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
|
|
||||||
if !ok {
|
|
||||||
if c.l != nil {
|
|
||||||
c.l.Warn("wireguard UDP listener received unexpected endpoint type")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addr := stdEp.AddrPort
|
|
||||||
r(addr, packets[i][:sizes[i]])
|
|
||||||
endpoints[i] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) ListenOut(r EncReader) {
|
|
||||||
for _, fn := range c.recvers {
|
|
||||||
go c.listen(fn, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if c.closed.Load() {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
ep := &wgconn.StdNetEndpoint{AddrPort: addr}
|
|
||||||
return c.bind.Send([][]byte{b}, ep)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) WriteBatch(datagrams []Datagram) error {
|
|
||||||
if len(datagrams) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if c.closed.Load() {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
max := c.batch
|
|
||||||
if max <= 0 {
|
|
||||||
max = len(datagrams)
|
|
||||||
if max == 0 {
|
|
||||||
max = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bufs := make([][]byte, 0, max)
|
|
||||||
var (
|
|
||||||
current netip.AddrPort
|
|
||||||
endpoint *wgconn.StdNetEndpoint
|
|
||||||
haveAddr bool
|
|
||||||
)
|
|
||||||
flush := func() error {
|
|
||||||
if len(bufs) == 0 || endpoint == nil {
|
|
||||||
bufs = bufs[:0]
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err := c.bind.Send(bufs, endpoint)
|
|
||||||
bufs = bufs[:0]
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, d := range datagrams {
|
|
||||||
if len(d.Payload) == 0 || !d.Addr.IsValid() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !haveAddr || d.Addr != current {
|
|
||||||
if err := flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
current = d.Addr
|
|
||||||
endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
|
|
||||||
haveAddr = true
|
|
||||||
}
|
|
||||||
bufs = append(bufs, d.Payload)
|
|
||||||
if len(bufs) >= max {
|
|
||||||
if err := flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
|
|
||||||
c.enableGSO = enableGSO
|
|
||||||
c.enableGRO = enableGRO
|
|
||||||
if maxSegments <= 0 {
|
|
||||||
maxSegments = 1
|
|
||||||
} else if maxSegments > wgconn.IdealBatchSize {
|
|
||||||
maxSegments = wgconn.IdealBatchSize
|
|
||||||
}
|
|
||||||
c.gsoMaxSeg = maxSegments
|
|
||||||
|
|
||||||
effectiveBatch := c.reqBatch
|
|
||||||
if enableGSO && c.bind != nil {
|
|
||||||
bindBatch := c.bind.BatchSize()
|
|
||||||
if effectiveBatch < bindBatch {
|
|
||||||
if c.l != nil {
|
|
||||||
c.l.WithFields(logrus.Fields{
|
|
||||||
"requested": c.reqBatch,
|
|
||||||
"effective": bindBatch,
|
|
||||||
}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
|
|
||||||
}
|
|
||||||
effectiveBatch = bindBatch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.batch = effectiveBatch
|
|
||||||
|
|
||||||
if c.l != nil {
|
|
||||||
c.l.WithFields(logrus.Fields{
|
|
||||||
"enableGSO": enableGSO,
|
|
||||||
"enableGRO": enableGRO,
|
|
||||||
"gsoMaxSegments": maxSegments,
|
|
||||||
}).Debug("configured wireguard UDP offload")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) ReloadConfig(*config.C) {
|
|
||||||
// WireGuard bind currently does not expose runtime configuration knobs.
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WGConn) Close() error {
|
|
||||||
var err error
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
c.closed.Store(true)
|
|
||||||
err = c.bind.Close()
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
//go:build !linux || android || e2e_testing
|
|
||||||
|
|
||||||
package udp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewWireguardListener is only available on Linux builds.
|
|
||||||
func NewWireguardListener(*logrus.Logger, netip.Addr, int, bool, int) (Conn, error) {
|
|
||||||
return nil, fmt.Errorf("wireguard experimental UDP listener is only supported on Linux")
|
|
||||||
}
|
|
||||||
@@ -1,539 +0,0 @@
|
|||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ Bind = (*StdNetBind)(nil)
|
|
||||||
)
|
|
||||||
|
|
||||||
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
|
||||||
// (see bind_windows.go), it may fall back to StdNetBind.
|
|
||||||
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
|
||||||
// methods for sending and receiving multiple datagrams per-syscall. See the
|
|
||||||
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
|
||||||
type StdNetBind struct {
|
|
||||||
mu sync.Mutex // protects all fields except as specified
|
|
||||||
ipv4 *net.UDPConn
|
|
||||||
ipv6 *net.UDPConn
|
|
||||||
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
|
||||||
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
|
||||||
|
|
||||||
// these three fields are not guarded by mu
|
|
||||||
udpAddrPool sync.Pool
|
|
||||||
ipv4MsgsPool sync.Pool
|
|
||||||
ipv6MsgsPool sync.Pool
|
|
||||||
|
|
||||||
blackhole4 bool
|
|
||||||
blackhole6 bool
|
|
||||||
|
|
||||||
listenAddr4 string
|
|
||||||
listenAddr6 string
|
|
||||||
bindV4 bool
|
|
||||||
bindV6 bool
|
|
||||||
reusePort bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newStdNetBind() *StdNetBind {
|
|
||||||
return &StdNetBind{
|
|
||||||
udpAddrPool: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return &net.UDPAddr{
|
|
||||||
IP: make([]byte, 16),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
ipv4MsgsPool: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
msgs := make([]ipv4.Message, IdealBatchSize)
|
|
||||||
for i := range msgs {
|
|
||||||
msgs[i].Buffers = make(net.Buffers, 1)
|
|
||||||
msgs[i].OOB = make([]byte, srcControlSize)
|
|
||||||
}
|
|
||||||
return &msgs
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
ipv6MsgsPool: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
msgs := make([]ipv6.Message, IdealBatchSize)
|
|
||||||
for i := range msgs {
|
|
||||||
msgs[i].Buffers = make(net.Buffers, 1)
|
|
||||||
msgs[i].OOB = make([]byte, srcControlSize)
|
|
||||||
}
|
|
||||||
return &msgs
|
|
||||||
},
|
|
||||||
},
|
|
||||||
bindV4: true,
|
|
||||||
bindV6: true,
|
|
||||||
reusePort: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStdNetBind creates a bind that listens on all interfaces.
|
|
||||||
func NewStdNetBind() *StdNetBind {
|
|
||||||
return newStdNetBind()
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStdNetBindForAddr creates a bind that listens on a specific address.
|
|
||||||
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
|
|
||||||
// IPv6 socket will be created.
|
|
||||||
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
|
|
||||||
b := newStdNetBind()
|
|
||||||
if addr.IsValid() {
|
|
||||||
if addr.IsUnspecified() {
|
|
||||||
// keep dual-stack defaults with empty listen addresses
|
|
||||||
} else if addr.Is4() {
|
|
||||||
b.listenAddr4 = addr.Unmap().String()
|
|
||||||
b.bindV4 = true
|
|
||||||
b.bindV6 = false
|
|
||||||
} else {
|
|
||||||
b.listenAddr6 = addr.Unmap().String()
|
|
||||||
b.bindV6 = true
|
|
||||||
b.bindV4 = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.reusePort = reusePort
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
type StdNetEndpoint struct {
|
|
||||||
// AddrPort is the endpoint destination.
|
|
||||||
netip.AddrPort
|
|
||||||
// src is the current sticky source address and interface index, if supported.
|
|
||||||
src struct {
|
|
||||||
netip.Addr
|
|
||||||
ifidx int32
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ Bind = (*StdNetBind)(nil)
|
|
||||||
_ Endpoint = &StdNetEndpoint{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
|
||||||
e, err := netip.ParseAddrPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StdNetEndpoint{
|
|
||||||
AddrPort: e,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) ClearSrc() {
|
|
||||||
e.src.ifidx = 0
|
|
||||||
e.src.Addr = netip.Addr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
|
||||||
return e.AddrPort.Addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
|
||||||
return e.src.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
|
||||||
return e.src.ifidx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
|
||||||
b, _ := e.AddrPort.MarshalBinary()
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToString() string {
|
|
||||||
return e.AddrPort.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcToString() string {
|
|
||||||
return e.src.Addr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
|
|
||||||
lc := listenConfig()
|
|
||||||
if s.reusePort {
|
|
||||||
base := lc.Control
|
|
||||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
|
||||||
if base != nil {
|
|
||||||
if err := base(network, address, c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.Control(func(fd uintptr) {
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := ":" + strconv.Itoa(port)
|
|
||||||
if host != "" {
|
|
||||||
addr = net.JoinHostPort(host, strconv.Itoa(port))
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := lc.ListenPacket(context.Background(), network, addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve port.
|
|
||||||
laddr := conn.LocalAddr()
|
|
||||||
uaddr, err := net.ResolveUDPAddr(
|
|
||||||
laddr.Network(),
|
|
||||||
laddr.String(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
return conn.(*net.UDPConn), uaddr.Port, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
|
|
||||||
if !s.bindV4 {
|
|
||||||
return nil, nil, port, nil
|
|
||||||
}
|
|
||||||
host := s.listenAddr4
|
|
||||||
conn, actualPort, err := s.listenNet("udp4", host, port)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return nil, nil, port, nil
|
|
||||||
}
|
|
||||||
return nil, nil, port, err
|
|
||||||
}
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
return conn, nil, actualPort, nil
|
|
||||||
}
|
|
||||||
pc := ipv4.NewPacketConn(conn)
|
|
||||||
return conn, pc, actualPort, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
|
|
||||||
if !s.bindV6 {
|
|
||||||
return nil, nil, port, nil
|
|
||||||
}
|
|
||||||
host := s.listenAddr6
|
|
||||||
conn, actualPort, err := s.listenNet("udp6", host, port)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return nil, nil, port, nil
|
|
||||||
}
|
|
||||||
return nil, nil, port, err
|
|
||||||
}
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
return conn, nil, actualPort, nil
|
|
||||||
}
|
|
||||||
pc := ipv6.NewPacketConn(conn)
|
|
||||||
return conn, pc, actualPort, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
var tries int
|
|
||||||
|
|
||||||
if s.ipv4 != nil || s.ipv6 != nil {
|
|
||||||
return nil, 0, ErrBindAlreadyOpen
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
|
||||||
// If uport is 0, we can retry on failure.
|
|
||||||
again:
|
|
||||||
port := int(uport)
|
|
||||||
var v4conn *net.UDPConn
|
|
||||||
var v6conn *net.UDPConn
|
|
||||||
var v4pc *ipv4.PacketConn
|
|
||||||
var v6pc *ipv6.PacketConn
|
|
||||||
|
|
||||||
v4conn, v4pc, port, err = s.openIPv4(port)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen on the same port as we're using for ipv4.
|
|
||||||
v6conn, v6pc, port, err = s.openIPv6(port)
|
|
||||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
|
||||||
if v4conn != nil {
|
|
||||||
v4conn.Close()
|
|
||||||
}
|
|
||||||
tries++
|
|
||||||
goto again
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if v4conn != nil {
|
|
||||||
v4conn.Close()
|
|
||||||
}
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var fns []ReceiveFunc
|
|
||||||
if v4conn != nil {
|
|
||||||
s.ipv4 = v4conn
|
|
||||||
if v4pc != nil {
|
|
||||||
s.ipv4PC = v4pc
|
|
||||||
}
|
|
||||||
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
|
|
||||||
}
|
|
||||||
if v6conn != nil {
|
|
||||||
s.ipv6 = v6conn
|
|
||||||
if v6pc != nil {
|
|
||||||
s.ipv6PC = v6pc
|
|
||||||
}
|
|
||||||
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
|
|
||||||
}
|
|
||||||
if len(fns) == 0 {
|
|
||||||
return nil, 0, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
|
|
||||||
return fns, uint16(port), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
|
||||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
||||||
defer s.ipv4MsgsPool.Put(msgs)
|
|
||||||
for i := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
|
||||||
}
|
|
||||||
var numMsgs int
|
|
||||||
if runtime.GOOS == "linux" && pc != nil {
|
|
||||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg := &(*msgs)[0]
|
|
||||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
numMsgs = 1
|
|
||||||
}
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
|
||||||
msg := &(*msgs)[i]
|
|
||||||
sizes[i] = msg.N
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
||||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
||||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
|
||||||
return numMsgs, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
|
||||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
|
||||||
defer s.ipv6MsgsPool.Put(msgs)
|
|
||||||
for i := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
|
||||||
}
|
|
||||||
var numMsgs int
|
|
||||||
if runtime.GOOS == "linux" && pc != nil {
|
|
||||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg := &(*msgs)[0]
|
|
||||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
numMsgs = 1
|
|
||||||
}
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
|
||||||
msg := &(*msgs)[i]
|
|
||||||
sizes[i] = msg.N
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
||||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
||||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
|
||||||
return numMsgs, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
|
||||||
// rename the IdealBatchSize constant to BatchSize.
|
|
||||||
func (s *StdNetBind) BatchSize() int {
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
return IdealBatchSize
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) Close() error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
var err1, err2 error
|
|
||||||
if s.ipv4 != nil {
|
|
||||||
err1 = s.ipv4.Close()
|
|
||||||
s.ipv4 = nil
|
|
||||||
s.ipv4PC = nil
|
|
||||||
}
|
|
||||||
if s.ipv6 != nil {
|
|
||||||
err2 = s.ipv6.Close()
|
|
||||||
s.ipv6 = nil
|
|
||||||
s.ipv6PC = nil
|
|
||||||
}
|
|
||||||
s.blackhole4 = false
|
|
||||||
s.blackhole6 = false
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
blackhole := s.blackhole4
|
|
||||||
conn := s.ipv4
|
|
||||||
var (
|
|
||||||
pc4 *ipv4.PacketConn
|
|
||||||
pc6 *ipv6.PacketConn
|
|
||||||
)
|
|
||||||
is6 := false
|
|
||||||
if endpoint.DstIP().Is6() {
|
|
||||||
blackhole = s.blackhole6
|
|
||||||
conn = s.ipv6
|
|
||||||
pc6 = s.ipv6PC
|
|
||||||
is6 = true
|
|
||||||
} else {
|
|
||||||
pc4 = s.ipv4PC
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
if blackhole {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if conn == nil {
|
|
||||||
return syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
if is6 {
|
|
||||||
return s.send6(conn, pc6, endpoint, bufs)
|
|
||||||
} else {
|
|
||||||
return s.send4(conn, pc4, endpoint, bufs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
|
|
||||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
||||||
as4 := ep.DstIP().As4()
|
|
||||||
copy(ua.IP, as4[:])
|
|
||||||
ua.IP = ua.IP[:4]
|
|
||||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
|
||||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
||||||
for i, buf := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = buf
|
|
||||||
(*msgs)[i].Addr = ua
|
|
||||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
n int
|
|
||||||
err error
|
|
||||||
start int
|
|
||||||
)
|
|
||||||
if runtime.GOOS == "linux" && pc != nil {
|
|
||||||
for {
|
|
||||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
for j := start; j < len(bufs); j++ {
|
|
||||||
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
|
|
||||||
if werr != nil {
|
|
||||||
err = werr
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if n == len((*msgs)[start:len(bufs)]) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
start += n
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.udpAddrPool.Put(ua)
|
|
||||||
s.ipv4MsgsPool.Put(msgs)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
|
|
||||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
||||||
as16 := ep.DstIP().As16()
|
|
||||||
copy(ua.IP, as16[:])
|
|
||||||
ua.IP = ua.IP[:16]
|
|
||||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
|
||||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
|
||||||
for i, buf := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = buf
|
|
||||||
(*msgs)[i].Addr = ua
|
|
||||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
n int
|
|
||||||
err error
|
|
||||||
start int
|
|
||||||
)
|
|
||||||
if runtime.GOOS == "linux" && pc != nil {
|
|
||||||
for {
|
|
||||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
for j := start; j < len(bufs); j++ {
|
|
||||||
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
|
|
||||||
if werr != nil {
|
|
||||||
err = werr
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if n == len((*msgs)[start:len(bufs)]) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
start += n
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.udpAddrPool.Put(ua)
|
|
||||||
s.ipv6MsgsPool.Put(msgs)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
|
||||||
)
|
|
||||||
|
|
||||||
// A ReceiveFunc receives at least one packet from the network and writes them
|
|
||||||
// into packets. On a successful read it returns the number of elements of
|
|
||||||
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
|
||||||
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
|
||||||
// and eps slice with a length greater than or equal to the length of packets.
|
|
||||||
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
|
||||||
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
|
||||||
|
|
||||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
|
||||||
//
|
|
||||||
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
|
||||||
// depending on the platform-specific implementation.
|
|
||||||
type Bind interface {
|
|
||||||
// Open puts the Bind into a listening state on a given port and reports the actual
|
|
||||||
// port that it bound to. Passing zero results in a random selection.
|
|
||||||
// fns is the set of functions that will be called to receive packets.
|
|
||||||
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
|
|
||||||
|
|
||||||
// Close closes the Bind listener.
|
|
||||||
// All fns returned by Open must return net.ErrClosed after a call to Close.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// SetMark sets the mark for each packet sent through this Bind.
|
|
||||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
|
||||||
SetMark(mark uint32) error
|
|
||||||
|
|
||||||
// Send writes one or more packets in bufs to address ep. The length of
|
|
||||||
// bufs must not exceed BatchSize().
|
|
||||||
Send(bufs [][]byte, ep Endpoint) error
|
|
||||||
|
|
||||||
// ParseEndpoint creates a new endpoint from a string.
|
|
||||||
ParseEndpoint(s string) (Endpoint, error)
|
|
||||||
|
|
||||||
// BatchSize is the number of buffers expected to be passed to
|
|
||||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
|
||||||
BatchSize() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// BindSocketToInterface is implemented by Bind objects that support being
|
|
||||||
// tied to a single network interface. Used by wireguard-windows.
|
|
||||||
type BindSocketToInterface interface {
|
|
||||||
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
|
||||||
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
|
||||||
// file descriptor peeked at. Used by wireguard-android.
|
|
||||||
type PeekLookAtSocketFd interface {
|
|
||||||
PeekLookAtSocketFd4() (fd int, err error)
|
|
||||||
PeekLookAtSocketFd6() (fd int, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// An Endpoint maintains the source/destination caching for a peer.
|
|
||||||
//
|
|
||||||
// dst: the remote address of a peer ("endpoint" in uapi terminology)
|
|
||||||
// src: the local address from which datagrams originate going to the peer
|
|
||||||
type Endpoint interface {
|
|
||||||
ClearSrc() // clears the source address
|
|
||||||
SrcToString() string // returns the local source address (ip:port)
|
|
||||||
DstToString() string // returns the destination address (ip:port)
|
|
||||||
DstToBytes() []byte // used for mac2 cookie calculations
|
|
||||||
DstIP() netip.Addr
|
|
||||||
SrcIP() netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrBindAlreadyOpen = errors.New("bind is already open")
|
|
||||||
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (fn ReceiveFunc) PrettyName() string {
|
|
||||||
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
|
||||||
// 0. cheese/taco.beansIPv6.func12.func21218-fm
|
|
||||||
name = strings.TrimSuffix(name, "-fm")
|
|
||||||
// 1. cheese/taco.beansIPv6.func12.func21218
|
|
||||||
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
|
|
||||||
name = name[idx+1:]
|
|
||||||
// 2. taco.beansIPv6.func12.func21218
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var idx int
|
|
||||||
for idx = len(name) - 1; idx >= 0; idx-- {
|
|
||||||
if name[idx] < '0' || name[idx] > '9' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if idx == len(name)-1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
const dotFunc = ".func"
|
|
||||||
if !strings.HasSuffix(name[:idx+1], dotFunc) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
name = name[:idx+1-len(dotFunc)]
|
|
||||||
// 3. taco.beansIPv6.func12
|
|
||||||
// 4. taco.beansIPv6
|
|
||||||
}
|
|
||||||
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
|
|
||||||
name = name[idx+1:]
|
|
||||||
// 5. beansIPv6
|
|
||||||
}
|
|
||||||
if name == "" {
|
|
||||||
return fmt.Sprintf("%p", fn)
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(name, "IPv4") {
|
|
||||||
return "v4"
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(name, "IPv6") {
|
|
||||||
return "v6"
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
|
||||||
// the max supported by a default configuration of macOS. Some platforms will
|
|
||||||
// silently clamp the value to other maximums, such as linux clamping to
|
|
||||||
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
|
||||||
// around this limitation)
|
|
||||||
const socketBufferSize = 7 << 20
|
|
||||||
|
|
||||||
// controlFn is the callback function signature from net.ListenConfig.Control.
|
|
||||||
// It is used to apply platform specific configuration to the socket prior to
|
|
||||||
// bind.
|
|
||||||
type controlFn func(network, address string, c syscall.RawConn) error
|
|
||||||
|
|
||||||
// controlFns is a list of functions that are called from the listen config
|
|
||||||
// that can apply socket options.
|
|
||||||
var controlFns = []controlFn{}
|
|
||||||
|
|
||||||
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
|
||||||
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
|
||||||
// information OOB configuration for sticky sockets.
|
|
||||||
func listenConfig() *net.ListenConfig {
|
|
||||||
return &net.ListenConfig{
|
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
|
||||||
for _, fn := range controlFns {
|
|
||||||
if err := fn(network, address, c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
controlFns = append(controlFns,
|
|
||||||
|
|
||||||
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
|
||||||
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
|
||||||
// fail silently - the result of failure is lower performance on very fast
|
|
||||||
// links or high latency links.
|
|
||||||
func(network, address string, c syscall.RawConn) error {
|
|
||||||
return c.Control(func(fd uintptr) {
|
|
||||||
// Set up to *mem_max
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
|
||||||
// Set beyond *mem_max if CAP_NET_ADMIN
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
|
|
||||||
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
|
||||||
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
|
||||||
func(network, address string, c syscall.RawConn) error {
|
|
||||||
var err error
|
|
||||||
switch network {
|
|
||||||
case "udp4":
|
|
||||||
if runtime.GOOS != "android" {
|
|
||||||
c.Control(func(fd uintptr) {
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case "udp6":
|
|
||||||
c.Control(func(fd uintptr) {
|
|
||||||
if runtime.GOOS != "android" {
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
func NewDefaultBind() Bind { return NewStdNetBind() }
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
func errShouldDisableUDPGSO(err error) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func errShouldDisableUDPGSO(err error) bool {
|
|
||||||
var serr *os.SyscallError
|
|
||||||
if errors.As(err, &serr) {
|
|
||||||
// EIO is returned by udp_send_skb() if the device driver does not have
|
|
||||||
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
|
||||||
// See:
|
|
||||||
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
|
||||||
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
|
||||||
return serr.Err == unix.EIO
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
// +build !linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
|
||||||
rc, err := conn.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = rc.Control(func(fd uintptr) {
|
|
||||||
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
|
||||||
txOffload = errSyscall == nil
|
|
||||||
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
|
|
||||||
rxOffload = errSyscall == nil && opt == 1
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
return txOffload, rxOffload
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
|
||||||
func getGSOSize(control []byte) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
|
||||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
|
||||||
// offloading control data.
|
|
||||||
const gsoControlSize = 0
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
sizeOfGSOData = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
|
||||||
func getGSOSize(control []byte) (int, error) {
|
|
||||||
var (
|
|
||||||
hdr unix.Cmsghdr
|
|
||||||
data []byte
|
|
||||||
rem = control
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
for len(rem) > unix.SizeofCmsghdr {
|
|
||||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
|
||||||
}
|
|
||||||
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
|
||||||
var gso uint16
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
|
||||||
return int(gso), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
|
||||||
// data in control untouched.
|
|
||||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
|
||||||
existingLen := len(*control)
|
|
||||||
avail := cap(*control) - existingLen
|
|
||||||
space := unix.CmsgSpace(sizeOfGSOData)
|
|
||||||
if avail < space {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*control = (*control)[:cap(*control)]
|
|
||||||
gsoControl := (*control)[existingLen:]
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
|
||||||
hdr.Level = unix.SOL_UDP
|
|
||||||
hdr.Type = unix.UDP_SEGMENT
|
|
||||||
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
|
||||||
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
|
||||||
*control = (*control)[:existingLen+space]
|
|
||||||
}
|
|
||||||
|
|
||||||
// gsoControlSize returns the recommended buffer size for pooling UDP
|
|
||||||
// offloading control data.
|
|
||||||
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
//go:build linux || openbsd || freebsd
|
|
||||||
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
var fwmarkIoctl int
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "linux", "android":
|
|
||||||
fwmarkIoctl = 36 /* unix.SO_MARK */
|
|
||||||
case "freebsd":
|
|
||||||
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
|
|
||||||
case "openbsd":
|
|
||||||
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
|
||||||
var operr error
|
|
||||||
if fwmarkIoctl == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if s.ipv4 != nil {
|
|
||||||
fd, err := s.ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = fd.Control(func(fd uintptr) {
|
|
||||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
err = operr
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if s.ipv6 != nil {
|
|
||||||
fd, err := s.ipv6.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = fd.Control(func(fd uintptr) {
|
|
||||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
err = operr
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import "net/netip"
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcToString() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
|
||||||
// {get,set}srcControl feature set, but use alternatively named flags and need
|
|
||||||
// ports and require testing.
|
|
||||||
|
|
||||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
|
||||||
// the source information found.
|
|
||||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
|
||||||
// the source information found.
|
|
||||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
|
||||||
// offloading control data.
|
|
||||||
const stickyControlSize = 0
|
|
||||||
|
|
||||||
const StdNetSupportsStickySockets = false
|
|
||||||
@@ -1,116 +0,0 @@
|
|||||||
//go:build linux && !android
|
|
||||||
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
|
||||||
// the source information found.
|
|
||||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
|
||||||
ep.ClearSrc()
|
|
||||||
|
|
||||||
var (
|
|
||||||
hdr unix.Cmsghdr
|
|
||||||
data []byte
|
|
||||||
rem []byte = control
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
for len(rem) > unix.SizeofCmsghdr {
|
|
||||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdr.Level == unix.IPPROTO_IP &&
|
|
||||||
hdr.Type == unix.IP_PKTINFO {
|
|
||||||
|
|
||||||
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
|
|
||||||
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
|
|
||||||
ep.src.ifidx = info.Ifindex
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdr.Level == unix.IPPROTO_IPV6 &&
|
|
||||||
hdr.Type == unix.IPV6_PKTINFO {
|
|
||||||
|
|
||||||
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
|
|
||||||
ep.src.Addr = netip.AddrFrom16(info.Addr)
|
|
||||||
ep.src.ifidx = int32(info.Ifindex)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
|
|
||||||
// panics if buf is of insufficient size.
|
|
||||||
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
|
|
||||||
size := int(unsafe.Sizeof(t))
|
|
||||||
if len(buf) < size {
|
|
||||||
panic("pktInfoFromBuf: buffer too small")
|
|
||||||
}
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
|
||||||
// and source ifindex found in ep. control's len will be set to 0 in the event
|
|
||||||
// that ep is a default value.
|
|
||||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
|
||||||
*control = (*control)[:cap(*control)]
|
|
||||||
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
|
|
||||||
*control = (*control)[:0]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
|
|
||||||
*control = (*control)[:0]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(*control) < srcControlSize {
|
|
||||||
*control = (*control)[:0]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
|
|
||||||
if ep.SrcIP().Is4() {
|
|
||||||
hdr.Level = unix.IPPROTO_IP
|
|
||||||
hdr.Type = unix.IP_PKTINFO
|
|
||||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
|
||||||
|
|
||||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
|
||||||
info.Ifindex = ep.src.ifidx
|
|
||||||
if ep.SrcIP().IsValid() {
|
|
||||||
info.Spec_dst = ep.SrcIP().As4()
|
|
||||||
}
|
|
||||||
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
|
||||||
} else {
|
|
||||||
hdr.Level = unix.IPPROTO_IPV6
|
|
||||||
hdr.Type = unix.IPV6_PKTINFO
|
|
||||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
|
||||||
|
|
||||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
|
||||||
info.Ifindex = uint32(ep.src.ifidx)
|
|
||||||
if ep.SrcIP().IsValid() {
|
|
||||||
info.Addr = ep.SrcIP().As16()
|
|
||||||
}
|
|
||||||
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
|
||||||
|
|
||||||
const StdNetSupportsStickySockets = true
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package tun
|
|
||||||
|
|
||||||
import "encoding/binary"
|
|
||||||
|
|
||||||
// TODO: Explore SIMD and/or other assembly optimizations.
|
|
||||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
|
||||||
ac := initial
|
|
||||||
i := 0
|
|
||||||
n := len(b)
|
|
||||||
for n >= 4 {
|
|
||||||
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
|
|
||||||
n -= 4
|
|
||||||
i += 4
|
|
||||||
}
|
|
||||||
for n >= 2 {
|
|
||||||
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
|
|
||||||
n -= 2
|
|
||||||
i += 2
|
|
||||||
}
|
|
||||||
if n == 1 {
|
|
||||||
ac += uint64(b[i]) << 8
|
|
||||||
}
|
|
||||||
return ac
|
|
||||||
}
|
|
||||||
|
|
||||||
func checksum(b []byte, initial uint64) uint16 {
|
|
||||||
ac := checksumNoFold(b, initial)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
ac = (ac >> 16) + (ac & 0xffff)
|
|
||||||
return uint16(ac)
|
|
||||||
}
|
|
||||||
|
|
||||||
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
|
||||||
sum := checksumNoFold(srcAddr, 0)
|
|
||||||
sum = checksumNoFold(dstAddr, sum)
|
|
||||||
sum = checksumNoFold([]byte{0, protocol}, sum)
|
|
||||||
tmp := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
|
||||||
return checksumNoFold(tmp, sum)
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
package tun
|
|
||||||
|
|
||||||
const VirtioNetHdrLen = virtioNetHdrLen
|
|
||||||
@@ -1,630 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package tun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrTooManySegments = errors.New("tun: too many segments for TSO")
|
|
||||||
|
|
||||||
const tcpFlagsOffset = 13
|
|
||||||
|
|
||||||
const (
|
|
||||||
tcpFlagFIN uint8 = 0x01
|
|
||||||
tcpFlagPSH uint8 = 0x08
|
|
||||||
tcpFlagACK uint8 = 0x10
|
|
||||||
)
|
|
||||||
|
|
||||||
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
|
|
||||||
// kernel symbol is virtio_net_hdr.
|
|
||||||
type virtioNetHdr struct {
|
|
||||||
flags uint8
|
|
||||||
gsoType uint8
|
|
||||||
hdrLen uint16
|
|
||||||
gsoSize uint16
|
|
||||||
csumStart uint16
|
|
||||||
csumOffset uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtioNetHdr) decode(b []byte) error {
|
|
||||||
if len(b) < virtioNetHdrLen {
|
|
||||||
return io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtioNetHdr) encode(b []byte) error {
|
|
||||||
if len(b) < virtioNetHdrLen {
|
|
||||||
return io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
|
||||||
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
|
||||||
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
|
||||||
)
|
|
||||||
|
|
||||||
// flowKey represents the key for a flow.
|
|
||||||
type flowKey struct {
|
|
||||||
srcAddr, dstAddr [16]byte
|
|
||||||
srcPort, dstPort uint16
|
|
||||||
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
|
|
||||||
type tcpGROTable struct {
|
|
||||||
itemsByFlow map[flowKey][]tcpGROItem
|
|
||||||
itemsPool [][]tcpGROItem
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTCPGROTable() *tcpGROTable {
|
|
||||||
t := &tcpGROTable{
|
|
||||||
itemsByFlow: make(map[flowKey][]tcpGROItem, wgconn.IdealBatchSize),
|
|
||||||
itemsPool: make([][]tcpGROItem, wgconn.IdealBatchSize),
|
|
||||||
}
|
|
||||||
for i := range t.itemsPool {
|
|
||||||
t.itemsPool[i] = make([]tcpGROItem, 0, wgconn.IdealBatchSize)
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
|
|
||||||
key := flowKey{}
|
|
||||||
addrSize := dstAddr - srcAddr
|
|
||||||
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
|
|
||||||
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
|
|
||||||
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
|
||||||
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
|
||||||
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
|
||||||
// returning the packets found for the flow, or inserting a new one if none
|
|
||||||
// is found.
|
|
||||||
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
|
||||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
|
||||||
items, ok := t.itemsByFlow[key]
|
|
||||||
if ok {
|
|
||||||
return items, ok
|
|
||||||
}
|
|
||||||
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
|
||||||
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert an item in the table for the provided packet and packet metadata.
|
|
||||||
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
|
||||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
|
||||||
item := tcpGROItem{
|
|
||||||
key: key,
|
|
||||||
bufsIndex: uint16(bufsIndex),
|
|
||||||
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
|
||||||
iphLen: uint8(tcphOffset),
|
|
||||||
tcphLen: uint8(tcphLen),
|
|
||||||
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
|
||||||
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
|
||||||
}
|
|
||||||
items, ok := t.itemsByFlow[key]
|
|
||||||
if !ok {
|
|
||||||
items = t.newItems()
|
|
||||||
}
|
|
||||||
items = append(items, item)
|
|
||||||
t.itemsByFlow[key] = items
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
|
||||||
items, _ := t.itemsByFlow[item.key]
|
|
||||||
items[i] = item
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
|
|
||||||
items, _ := t.itemsByFlow[key]
|
|
||||||
items = append(items[:i], items[i+1:]...)
|
|
||||||
t.itemsByFlow[key] = items
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
|
||||||
// of a GRO evaluation across a vector of packets.
|
|
||||||
type tcpGROItem struct {
|
|
||||||
key flowKey
|
|
||||||
sentSeq uint32 // the sequence number
|
|
||||||
bufsIndex uint16 // the index into the original bufs slice
|
|
||||||
numMerged uint16 // the number of packets merged into this item
|
|
||||||
gsoSize uint16 // payload size
|
|
||||||
iphLen uint8 // ip header len
|
|
||||||
tcphLen uint8 // tcp header len
|
|
||||||
pshSet bool // psh flag is set
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) newItems() []tcpGROItem {
|
|
||||||
var items []tcpGROItem
|
|
||||||
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
|
||||||
return items
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) reset() {
|
|
||||||
for k, items := range t.itemsByFlow {
|
|
||||||
items = items[:0]
|
|
||||||
t.itemsPool = append(t.itemsPool, items)
|
|
||||||
delete(t.itemsByFlow, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// canCoalesce represents the outcome of checking if two TCP packets are
|
|
||||||
// candidates for coalescing.
|
|
||||||
type canCoalesce int
|
|
||||||
|
|
||||||
const (
|
|
||||||
coalescePrepend canCoalesce = -1
|
|
||||||
coalesceUnavailable canCoalesce = 0
|
|
||||||
coalesceAppend canCoalesce = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
|
||||||
// described by item. This function makes considerations that match the kernel's
|
|
||||||
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
|
||||||
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
|
||||||
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
|
||||||
if tcphLen != item.tcphLen {
|
|
||||||
// cannot coalesce with unequal tcp options len
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if tcphLen > 20 {
|
|
||||||
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
|
||||||
// cannot coalesce with unequal tcp options
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if pkt[0]>>4 == 6 {
|
|
||||||
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
|
|
||||||
// cannot coalesce with unequal Traffic class values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[7] != pktTarget[7] {
|
|
||||||
// cannot coalesce with unequal Hop limit values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if pkt[1] != pktTarget[1] {
|
|
||||||
// cannot coalesce with unequal ToS values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[6]>>5 != pktTarget[6]>>5 {
|
|
||||||
// cannot coalesce with unequal DF or reserved bits. MF is checked
|
|
||||||
// further up the stack.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[8] != pktTarget[8] {
|
|
||||||
// cannot coalesce with unequal TTL values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// seq adjacency
|
|
||||||
lhsLen := item.gsoSize
|
|
||||||
lhsLen += item.numMerged * item.gsoSize
|
|
||||||
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
|
||||||
if item.pshSet {
|
|
||||||
// We cannot append to a segment that has the PSH flag set, PSH
|
|
||||||
// can only be set on the final segment in a reassembled group.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
|
||||||
// A smaller than gsoSize packet has been appended previously.
|
|
||||||
// Nothing can come after a smaller packet on the end.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize > item.gsoSize {
|
|
||||||
// We cannot have a larger packet following a smaller one.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
return coalesceAppend
|
|
||||||
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
|
||||||
if pshSet {
|
|
||||||
// We cannot prepend with a segment that has the PSH flag set, PSH
|
|
||||||
// can only be set on the final segment in a reassembled group.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize < item.gsoSize {
|
|
||||||
// We cannot have a larger packet following a smaller one.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
|
||||||
// There's at least one previous merge, and we're larger than all
|
|
||||||
// previous. This would put multiple smaller packets on the end.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
return coalescePrepend
|
|
||||||
}
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
|
|
||||||
srcAddrAt := ipv4SrcAddrOffset
|
|
||||||
addrSize := 4
|
|
||||||
if isV6 {
|
|
||||||
srcAddrAt = ipv6SrcAddrOffset
|
|
||||||
addrSize = 16
|
|
||||||
}
|
|
||||||
tcpTotalLen := uint16(len(pkt) - int(iphLen))
|
|
||||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
|
|
||||||
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// coalesceResult represents the result of attempting to coalesce two TCP
|
|
||||||
// packets.
|
|
||||||
type coalesceResult int
|
|
||||||
|
|
||||||
const (
|
|
||||||
coalesceInsufficientCap coalesceResult = 0
|
|
||||||
coalescePSHEnding coalesceResult = 1
|
|
||||||
coalesceItemInvalidCSum coalesceResult = 2
|
|
||||||
coalescePktInvalidCSum coalesceResult = 3
|
|
||||||
coalesceSuccess coalesceResult = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
|
|
||||||
// item, returning the outcome. This function may swap bufs elements in the
|
|
||||||
// event of a prepend as item's bufs index is already being tracked for writing
|
|
||||||
// to a Device.
|
|
||||||
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
|
||||||
var pktHead []byte // the packet that will end up at the front
|
|
||||||
headersLen := item.iphLen + item.tcphLen
|
|
||||||
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
|
||||||
|
|
||||||
// Copy data
|
|
||||||
if mode == coalescePrepend {
|
|
||||||
pktHead = pkt
|
|
||||||
if cap(pkt)-bufsOffset < coalescedLen {
|
|
||||||
// We don't want to allocate a new underlying array if capacity is
|
|
||||||
// too small.
|
|
||||||
return coalesceInsufficientCap
|
|
||||||
}
|
|
||||||
if pshSet {
|
|
||||||
return coalescePSHEnding
|
|
||||||
}
|
|
||||||
if item.numMerged == 0 {
|
|
||||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
|
||||||
return coalesceItemInvalidCSum
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
|
||||||
return coalescePktInvalidCSum
|
|
||||||
}
|
|
||||||
item.sentSeq = seq
|
|
||||||
extendBy := coalescedLen - len(pktHead)
|
|
||||||
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
|
|
||||||
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
|
|
||||||
// Flip the slice headers in bufs as part of prepend. The index of item
|
|
||||||
// is already being tracked for writing.
|
|
||||||
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
|
|
||||||
} else {
|
|
||||||
pktHead = bufs[item.bufsIndex][bufsOffset:]
|
|
||||||
if cap(pktHead)-bufsOffset < coalescedLen {
|
|
||||||
// We don't want to allocate a new underlying array if capacity is
|
|
||||||
// too small.
|
|
||||||
return coalesceInsufficientCap
|
|
||||||
}
|
|
||||||
if item.numMerged == 0 {
|
|
||||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
|
||||||
return coalesceItemInvalidCSum
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
|
||||||
return coalescePktInvalidCSum
|
|
||||||
}
|
|
||||||
if pshSet {
|
|
||||||
// We are appending a segment with PSH set.
|
|
||||||
item.pshSet = pshSet
|
|
||||||
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
|
|
||||||
}
|
|
||||||
extendBy := len(pkt) - int(headersLen)
|
|
||||||
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
|
||||||
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
if gsoSize > item.gsoSize {
|
|
||||||
item.gsoSize = gsoSize
|
|
||||||
}
|
|
||||||
hdr := virtioNetHdr{
|
|
||||||
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
|
||||||
hdrLen: uint16(headersLen),
|
|
||||||
gsoSize: uint16(item.gsoSize),
|
|
||||||
csumStart: uint16(item.iphLen),
|
|
||||||
csumOffset: 16,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
|
|
||||||
// (IPv4) header checksum.
|
|
||||||
if isV6 {
|
|
||||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
|
||||||
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
|
|
||||||
} else {
|
|
||||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
|
||||||
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
|
|
||||||
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
|
|
||||||
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
|
|
||||||
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
|
|
||||||
}
|
|
||||||
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
|
|
||||||
|
|
||||||
// Calculate the pseudo header checksum and place it at the TCP checksum
|
|
||||||
// offset. Downstream checksum offloading will combine this with computation
|
|
||||||
// of the tcp header and payload checksum.
|
|
||||||
addrLen := 4
|
|
||||||
addrOffset := ipv4SrcAddrOffset
|
|
||||||
if isV6 {
|
|
||||||
addrLen = 16
|
|
||||||
addrOffset = ipv6SrcAddrOffset
|
|
||||||
}
|
|
||||||
srcAddrAt := bufsOffset + addrOffset
|
|
||||||
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
|
||||||
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
|
||||||
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
|
|
||||||
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
|
|
||||||
|
|
||||||
item.numMerged++
|
|
||||||
return coalesceSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4FlagMoreFragments uint8 = 0x20
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4SrcAddrOffset = 12
|
|
||||||
ipv6SrcAddrOffset = 8
|
|
||||||
maxUint16 = 1<<16 - 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
|
|
||||||
// existing packets tracked in table. It will return false when pktI is not
|
|
||||||
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
|
|
||||||
// should be written to the Device.
|
|
||||||
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
|
|
||||||
pkt := bufs[pktI][offset:]
|
|
||||||
if len(pkt) > maxUint16 {
|
|
||||||
// A valid IPv4 or IPv6 packet will never exceed this.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
iphLen := int((pkt[0] & 0x0F) * 4)
|
|
||||||
if isV6 {
|
|
||||||
iphLen = 40
|
|
||||||
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
|
||||||
if ipv6HPayloadLen != len(pkt)-iphLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
|
||||||
if totalLen != len(pkt) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(pkt) < iphLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
|
|
||||||
if tcphLen < 20 || tcphLen > 60 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(pkt) < iphLen+tcphLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !isV6 {
|
|
||||||
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
|
||||||
// no GRO support for fragmented segments for now
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tcpFlags := pkt[iphLen+tcpFlagsOffset]
|
|
||||||
var pshSet bool
|
|
||||||
// not a candidate if any non-ACK flags (except PSH+ACK) are set
|
|
||||||
if tcpFlags != tcpFlagACK {
|
|
||||||
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
pshSet = true
|
|
||||||
}
|
|
||||||
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
|
|
||||||
// not a candidate if payload len is 0
|
|
||||||
if gsoSize < 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
|
|
||||||
srcAddrOffset := ipv4SrcAddrOffset
|
|
||||||
addrLen := 4
|
|
||||||
if isV6 {
|
|
||||||
srcAddrOffset = ipv6SrcAddrOffset
|
|
||||||
addrLen = 16
|
|
||||||
}
|
|
||||||
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
|
||||||
if !existing {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := len(items) - 1; i >= 0; i-- {
|
|
||||||
// In the best case of packets arriving in order iterating in reverse is
|
|
||||||
// more efficient if there are multiple items for a given flow. This
|
|
||||||
// also enables a natural table.deleteAt() in the
|
|
||||||
// coalesceItemInvalidCSum case without the need for index tracking.
|
|
||||||
// This algorithm makes a best effort to coalesce in the event of
|
|
||||||
// unordered packets, where pkt may land anywhere in items from a
|
|
||||||
// sequence number perspective, however once an item is inserted into
|
|
||||||
// the table it is never compared across other items later.
|
|
||||||
item := items[i]
|
|
||||||
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
|
|
||||||
if can != coalesceUnavailable {
|
|
||||||
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
|
|
||||||
switch result {
|
|
||||||
case coalesceSuccess:
|
|
||||||
table.updateAt(item, i)
|
|
||||||
return true
|
|
||||||
case coalesceItemInvalidCSum:
|
|
||||||
// delete the item with an invalid csum
|
|
||||||
table.deleteAt(item.key, i)
|
|
||||||
case coalescePktInvalidCSum:
|
|
||||||
// no point in inserting an item that we can't coalesce
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// failed to coalesce with any other packets; store the item in the flow
|
|
||||||
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTCP4NoIPOptions(b []byte) bool {
|
|
||||||
if len(b) < 40 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]>>4 != 4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]&0x0F != 5 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[9] != unix.IPPROTO_TCP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTCP6NoEH(b []byte) bool {
|
|
||||||
if len(b) < 60 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]>>4 != 6 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[6] != unix.IPPROTO_TCP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
|
||||||
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
|
|
||||||
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
|
||||||
// and recycle them across vectors of packets.
|
|
||||||
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
|
|
||||||
for i := range bufs {
|
|
||||||
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
|
||||||
return errors.New("invalid offset")
|
|
||||||
}
|
|
||||||
var coalesced bool
|
|
||||||
switch {
|
|
||||||
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
|
|
||||||
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
|
|
||||||
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
|
|
||||||
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
|
|
||||||
}
|
|
||||||
if !coalesced {
|
|
||||||
hdr := virtioNetHdr{}
|
|
||||||
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*toWrite = append(*toWrite, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpTSO splits packets from in into outBuffs, writing the size of each
|
|
||||||
// element into sizes. It returns the number of buffers populated, and/or an
|
|
||||||
// error.
|
|
||||||
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
|
|
||||||
iphLen := int(hdr.csumStart)
|
|
||||||
srcAddrOffset := ipv6SrcAddrOffset
|
|
||||||
addrLen := 16
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
in[10], in[11] = 0, 0 // clear ipv4 header checksum
|
|
||||||
srcAddrOffset = ipv4SrcAddrOffset
|
|
||||||
addrLen = 4
|
|
||||||
}
|
|
||||||
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
|
|
||||||
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
|
|
||||||
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
|
|
||||||
nextSegmentDataAt := int(hdr.hdrLen)
|
|
||||||
i := 0
|
|
||||||
for ; nextSegmentDataAt < len(in); i++ {
|
|
||||||
if i == len(outBuffs) {
|
|
||||||
return i - 1, ErrTooManySegments
|
|
||||||
}
|
|
||||||
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
|
|
||||||
if nextSegmentEnd > len(in) {
|
|
||||||
nextSegmentEnd = len(in)
|
|
||||||
}
|
|
||||||
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
|
||||||
totalLen := int(hdr.hdrLen) + segmentDataLen
|
|
||||||
sizes[i] = totalLen
|
|
||||||
out := outBuffs[i][outOffset:]
|
|
||||||
|
|
||||||
copy(out, in[:iphLen])
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
// For IPv4 we are responsible for incrementing the ID field,
|
|
||||||
// updating the total len field, and recalculating the header
|
|
||||||
// checksum.
|
|
||||||
if i > 0 {
|
|
||||||
id := binary.BigEndian.Uint16(out[4:])
|
|
||||||
id += uint16(i)
|
|
||||||
binary.BigEndian.PutUint16(out[4:], id)
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
|
||||||
ipv4CSum := ^checksum(out[:iphLen], 0)
|
|
||||||
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
|
||||||
} else {
|
|
||||||
// For IPv6 we are responsible for updating the payload length field.
|
|
||||||
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCP header
|
|
||||||
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
|
|
||||||
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
|
|
||||||
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
|
|
||||||
if nextSegmentEnd != len(in) {
|
|
||||||
// FIN and PSH should only be set on last segment
|
|
||||||
clearFlags := tcpFlagFIN | tcpFlagPSH
|
|
||||||
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
|
|
||||||
}
|
|
||||||
|
|
||||||
// payload
|
|
||||||
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
|
||||||
|
|
||||||
// TCP checksum
|
|
||||||
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
|
|
||||||
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
|
|
||||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
|
|
||||||
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
|
|
||||||
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
|
|
||||||
|
|
||||||
nextSegmentDataAt += int(hdr.gsoSize)
|
|
||||||
}
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
|
||||||
cSumAt := cSumStart + cSumOffset
|
|
||||||
// The initial value at the checksum offset should be summed with the
|
|
||||||
// checksum we compute. This is typically the pseudo-header checksum.
|
|
||||||
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
|
||||||
in[cSumAt], in[cSumAt+1] = 0, 0
|
|
||||||
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package tun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Event int
|
|
||||||
|
|
||||||
const (
|
|
||||||
EventUp = 1 << iota
|
|
||||||
EventDown
|
|
||||||
EventMTUUpdate
|
|
||||||
)
|
|
||||||
|
|
||||||
type Device interface {
|
|
||||||
// File returns the file descriptor of the device.
|
|
||||||
File() *os.File
|
|
||||||
|
|
||||||
// Read one or more packets from the Device (without any additional headers).
|
|
||||||
// On a successful read it returns the number of packets read, and sets
|
|
||||||
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
|
|
||||||
// A nonzero offset can be used to instruct the Device on where to begin
|
|
||||||
// reading into each element of the bufs slice.
|
|
||||||
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
|
|
||||||
|
|
||||||
// Write one or more packets to the device (without any additional headers).
|
|
||||||
// On a successful write it returns the number of packets written. A nonzero
|
|
||||||
// offset can be used to instruct the Device on where to begin writing from
|
|
||||||
// each packet contained within the bufs slice.
|
|
||||||
Write(bufs [][]byte, offset int) (int, error)
|
|
||||||
|
|
||||||
// MTU returns the MTU of the Device.
|
|
||||||
MTU() (int, error)
|
|
||||||
|
|
||||||
// Name returns the current name of the Device.
|
|
||||||
Name() (string, error)
|
|
||||||
|
|
||||||
// Events returns a channel of type Event, which is fed Device events.
|
|
||||||
Events() <-chan Event
|
|
||||||
|
|
||||||
// Close stops the Device and closes the Event channel.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// BatchSize returns the preferred/max number of packets that can be read or
|
|
||||||
// written in a single read/write call. BatchSize must not change over the
|
|
||||||
// lifetime of a Device.
|
|
||||||
BatchSize() int
|
|
||||||
}
|
|
||||||
@@ -1,664 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
//
|
|
||||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
package tun
|
|
||||||
|
|
||||||
/* Implementation of the TUN device interface for linux
|
|
||||||
*/
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
cloneDevicePath = "/dev/net/tun"
|
|
||||||
ifReqSize = unix.IFNAMSIZ + 64
|
|
||||||
)
|
|
||||||
|
|
||||||
type NativeTun struct {
|
|
||||||
tunFile *os.File
|
|
||||||
index int32 // if index
|
|
||||||
errors chan error // async error handling
|
|
||||||
events chan Event // device related events
|
|
||||||
netlinkSock int
|
|
||||||
netlinkCancel *rwcancel.RWCancel
|
|
||||||
hackListenerClosed sync.Mutex
|
|
||||||
statusListenersShutdown chan struct{}
|
|
||||||
batchSize int
|
|
||||||
vnetHdr bool
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
|
|
||||||
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
|
||||||
nameCache string // name of interface
|
|
||||||
nameErr error
|
|
||||||
|
|
||||||
readOpMu sync.Mutex // readOpMu guards readBuff
|
|
||||||
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
|
||||||
|
|
||||||
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
|
|
||||||
toWrite []int
|
|
||||||
tcp4GROTable, tcp6GROTable *tcpGROTable
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) File() *os.File {
|
|
||||||
return tun.tunFile
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) routineHackListener() {
|
|
||||||
defer tun.hackListenerClosed.Unlock()
|
|
||||||
/* This is needed for the detection to work across network namespaces
|
|
||||||
* If you are reading this and know a better method, please get in touch.
|
|
||||||
*/
|
|
||||||
last := 0
|
|
||||||
const (
|
|
||||||
up = 1
|
|
||||||
down = 2
|
|
||||||
)
|
|
||||||
for {
|
|
||||||
sysconn, err := tun.tunFile.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err2 := sysconn.Control(func(fd uintptr) {
|
|
||||||
_, err = unix.Write(int(fd), nil)
|
|
||||||
})
|
|
||||||
if err2 != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
switch err {
|
|
||||||
case unix.EINVAL:
|
|
||||||
if last != up {
|
|
||||||
// If the tunnel is up, it reports that write() is
|
|
||||||
// allowed but we provided invalid data.
|
|
||||||
tun.events <- EventUp
|
|
||||||
last = up
|
|
||||||
}
|
|
||||||
case unix.EIO:
|
|
||||||
if last != down {
|
|
||||||
// If the tunnel is down, it reports that no I/O
|
|
||||||
// is possible, without checking our provided data.
|
|
||||||
tun.events <- EventDown
|
|
||||||
last = down
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
// nothing
|
|
||||||
case <-tun.statusListenersShutdown:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createNetlinkSocket() (int, error) {
|
|
||||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
saddr := &unix.SockaddrNetlink{
|
|
||||||
Family: unix.AF_NETLINK,
|
|
||||||
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
|
|
||||||
}
|
|
||||||
err = unix.Bind(sock, saddr)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return sock, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) routineNetlinkListener() {
|
|
||||||
defer func() {
|
|
||||||
unix.Close(tun.netlinkSock)
|
|
||||||
tun.hackListenerClosed.Lock()
|
|
||||||
close(tun.events)
|
|
||||||
tun.netlinkCancel.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for msg := make([]byte, 1<<16); ; {
|
|
||||||
var err error
|
|
||||||
var msgn int
|
|
||||||
for {
|
|
||||||
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
|
||||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !tun.netlinkCancel.ReadyRead() {
|
|
||||||
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-tun.statusListenersShutdown:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
wasEverUp := false
|
|
||||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
|
||||||
|
|
||||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
|
||||||
|
|
||||||
if int(hdr.Len) > len(remain) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch hdr.Type {
|
|
||||||
case unix.NLMSG_DONE:
|
|
||||||
remain = []byte{}
|
|
||||||
|
|
||||||
case unix.RTM_NEWLINK:
|
|
||||||
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
|
|
||||||
if info.Index != tun.index {
|
|
||||||
// not our interface
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
|
||||||
tun.events <- EventUp
|
|
||||||
wasEverUp = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
|
||||||
// Don't emit EventDown before we've ever emitted EventUp.
|
|
||||||
// This avoids a startup race with HackListener, which
|
|
||||||
// might detect Up before we have finished reporting Down.
|
|
||||||
if wasEverUp {
|
|
||||||
tun.events <- EventDown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tun.events <- EventMTUUpdate
|
|
||||||
|
|
||||||
default:
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getIFIndex(name string) (int32, error) {
|
|
||||||
fd, err := unix.Socket(
|
|
||||||
unix.AF_INET,
|
|
||||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer unix.Close(fd)
|
|
||||||
|
|
||||||
var ifr [ifReqSize]byte
|
|
||||||
copy(ifr[:], name)
|
|
||||||
_, _, errno := unix.Syscall(
|
|
||||||
unix.SYS_IOCTL,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(unix.SIOCGIFINDEX),
|
|
||||||
uintptr(unsafe.Pointer(&ifr[0])),
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return 0, errno
|
|
||||||
}
|
|
||||||
|
|
||||||
return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) setMTU(n int) error {
|
|
||||||
name, err := tun.Name()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open datagram socket
|
|
||||||
fd, err := unix.Socket(
|
|
||||||
unix.AF_INET,
|
|
||||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer unix.Close(fd)
|
|
||||||
|
|
||||||
var ifr [ifReqSize]byte
|
|
||||||
copy(ifr[:], name)
|
|
||||||
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
|
||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
|
||||||
unix.SYS_IOCTL,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(unix.SIOCSIFMTU),
|
|
||||||
uintptr(unsafe.Pointer(&ifr[0])),
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return errno
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) routineNetlinkRead() {
|
|
||||||
defer func() {
|
|
||||||
unix.Close(tun.netlinkSock)
|
|
||||||
tun.hackListenerClosed.Lock()
|
|
||||||
close(tun.events)
|
|
||||||
tun.netlinkCancel.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for msg := make([]byte, 1<<16); ; {
|
|
||||||
var err error
|
|
||||||
var msgn int
|
|
||||||
for {
|
|
||||||
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
|
||||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !tun.netlinkCancel.ReadyRead() {
|
|
||||||
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
wasEverUp := false
|
|
||||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
|
||||||
|
|
||||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
|
||||||
|
|
||||||
if int(hdr.Len) > len(remain) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch hdr.Type {
|
|
||||||
case unix.NLMSG_DONE:
|
|
||||||
remain = []byte{}
|
|
||||||
|
|
||||||
case unix.RTM_NEWLINK:
|
|
||||||
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
|
|
||||||
if info.Index != tun.index {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
|
||||||
tun.events <- EventUp
|
|
||||||
wasEverUp = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
|
||||||
if wasEverUp {
|
|
||||||
tun.events <- EventDown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tun.events <- EventMTUUpdate
|
|
||||||
|
|
||||||
default:
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) routineNetlink() {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
tun.netlinkSock, err = createNetlinkSocket()
|
|
||||||
if err != nil {
|
|
||||||
tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
|
|
||||||
if err != nil {
|
|
||||||
tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go tun.routineNetlinkListener()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) Close() error {
|
|
||||||
var err1, err2 error
|
|
||||||
tun.closeOnce.Do(func() {
|
|
||||||
if tun.statusListenersShutdown != nil {
|
|
||||||
close(tun.statusListenersShutdown)
|
|
||||||
if tun.netlinkCancel != nil {
|
|
||||||
err1 = tun.netlinkCancel.Cancel()
|
|
||||||
}
|
|
||||||
} else if tun.events != nil {
|
|
||||||
close(tun.events)
|
|
||||||
}
|
|
||||||
err2 = tun.tunFile.Close()
|
|
||||||
})
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) BatchSize() int {
|
|
||||||
return tun.batchSize
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// TODO: support TSO with ECN bits
|
|
||||||
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
|
||||||
)
|
|
||||||
|
|
||||||
func (tun *NativeTun) initFromFlags(name string) error {
|
|
||||||
sc, err := tun.tunFile.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if e := sc.Control(func(fd uintptr) {
|
|
||||||
var (
|
|
||||||
ifr *unix.Ifreq
|
|
||||||
)
|
|
||||||
ifr, err = unix.NewIfreq(name)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
got := ifr.Uint16()
|
|
||||||
if got&unix.IFF_VNET_HDR != 0 {
|
|
||||||
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tun.vnetHdr = true
|
|
||||||
tun.batchSize = wgconn.IdealBatchSize
|
|
||||||
} else {
|
|
||||||
tun.batchSize = 1
|
|
||||||
}
|
|
||||||
}); e != nil {
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTUN creates a Device with the provided name and MTU.
|
|
||||||
func CreateTUN(name string, mtu int) (Device, error) {
|
|
||||||
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
|
|
||||||
}
|
|
||||||
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
|
|
||||||
tun, err := CreateTUNFromFile(fd, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if name != "tun" {
|
|
||||||
if err := tun.(*NativeTun).initFromFlags(name); err != nil {
|
|
||||||
tun.Close()
|
|
||||||
return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tun, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
|
|
||||||
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
|
||||||
tun := &NativeTun{
|
|
||||||
tunFile: file,
|
|
||||||
errors: make(chan error, 5),
|
|
||||||
events: make(chan Event, 5),
|
|
||||||
}
|
|
||||||
|
|
||||||
name, err := tun.Name()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to determine TUN name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tun.initFromFlags(name); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query TUN flags: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tun.batchSize == 0 {
|
|
||||||
tun.batchSize = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
tun.index, err = getIFIndex(name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get TUN index: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = tun.setMTU(mtu); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set MTU: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tun.statusListenersShutdown = make(chan struct{})
|
|
||||||
go tun.routineNetlink()
|
|
||||||
|
|
||||||
if tun.batchSize == 0 {
|
|
||||||
tun.batchSize = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
tun.tcp4GROTable = newTCPGROTable()
|
|
||||||
tun.tcp6GROTable = newTCPGROTable()
|
|
||||||
|
|
||||||
return tun, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) Name() (string, error) {
|
|
||||||
tun.nameOnce.Do(tun.initNameCache)
|
|
||||||
return tun.nameCache, tun.nameErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) initNameCache() {
|
|
||||||
sysconn, err := tun.tunFile.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
tun.nameErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = sysconn.Control(func(fd uintptr) {
|
|
||||||
var ifr [ifReqSize]byte
|
|
||||||
_, _, errno := unix.Syscall(
|
|
||||||
unix.SYS_IOCTL,
|
|
||||||
fd,
|
|
||||||
uintptr(unix.TUNGETIFF),
|
|
||||||
uintptr(unsafe.Pointer(&ifr[0])),
|
|
||||||
)
|
|
||||||
if errno != 0 {
|
|
||||||
tun.nameErr = errno
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tun.nameCache = unix.ByteSliceToString(ifr[:])
|
|
||||||
})
|
|
||||||
if err != nil && tun.nameErr == nil {
|
|
||||||
tun.nameErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) MTU() (int, error) {
|
|
||||||
name, err := tun.Name()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open datagram socket
|
|
||||||
fd, err := unix.Socket(
|
|
||||||
unix.AF_INET,
|
|
||||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer unix.Close(fd)
|
|
||||||
|
|
||||||
var ifr [ifReqSize]byte
|
|
||||||
copy(ifr[:], name)
|
|
||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
|
||||||
unix.SYS_IOCTL,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(unix.SIOCGIFMTU),
|
|
||||||
uintptr(unsafe.Pointer(&ifr[0])),
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return 0, errno
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) Events() <-chan Event {
|
|
||||||
return tun.events
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
|
||||||
tun.writeOpMu.Lock()
|
|
||||||
defer func() {
|
|
||||||
tun.tcp4GROTable.reset()
|
|
||||||
tun.tcp6GROTable.reset()
|
|
||||||
tun.writeOpMu.Unlock()
|
|
||||||
}()
|
|
||||||
var (
|
|
||||||
errs error
|
|
||||||
total int
|
|
||||||
)
|
|
||||||
tun.toWrite = tun.toWrite[:0]
|
|
||||||
if tun.vnetHdr {
|
|
||||||
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
offset -= virtioNetHdrLen
|
|
||||||
} else {
|
|
||||||
for i := range bufs {
|
|
||||||
tun.toWrite = append(tun.toWrite, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bufsI := range tun.toWrite {
|
|
||||||
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
|
|
||||||
if errors.Is(err, syscall.EBADFD) {
|
|
||||||
return total, os.ErrClosed
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
errs = errors.Join(errs, err)
|
|
||||||
} else {
|
|
||||||
total += n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return total, errs
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
|
|
||||||
// each buffer. It mutates sizes to reflect the size of each element of bufs,
|
|
||||||
// and returns the number of packets read.
|
|
||||||
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
|
|
||||||
var hdr virtioNetHdr
|
|
||||||
if err := hdr.decode(in); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
in = in[virtioNetHdrLen:]
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
|
||||||
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
|
||||||
if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(in) > len(bufs[0][offset:]) {
|
|
||||||
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
|
|
||||||
}
|
|
||||||
n := copy(bufs[0][offset:], in)
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
|
||||||
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipVersion := in[0] >> 4
|
|
||||||
switch ipVersion {
|
|
||||||
case 4:
|
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
|
||||||
}
|
|
||||||
case 6:
|
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
|
||||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(in) <= int(hdr.csumStart+12) {
|
|
||||||
return 0, errors.New("packet is too short")
|
|
||||||
}
|
|
||||||
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
|
||||||
if tcpHLen < 20 || tcpHLen > 60 {
|
|
||||||
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
|
||||||
}
|
|
||||||
hdr.hdrLen = hdr.csumStart + tcpHLen
|
|
||||||
if len(in) < int(hdr.hdrLen) {
|
|
||||||
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
|
||||||
}
|
|
||||||
if hdr.hdrLen < hdr.csumStart {
|
|
||||||
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
|
|
||||||
}
|
|
||||||
cSumAt := int(hdr.csumStart + hdr.csumOffset)
|
|
||||||
if cSumAt+1 >= len(in) {
|
|
||||||
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpTSO(in, hdr, bufs, sizes, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
|
||||||
tun.readOpMu.Lock()
|
|
||||||
defer tun.readOpMu.Unlock()
|
|
||||||
select {
|
|
||||||
case err := <-tun.errors:
|
|
||||||
return 0, err
|
|
||||||
default:
|
|
||||||
readInto := bufs[0][offset:]
|
|
||||||
if tun.vnetHdr {
|
|
||||||
readInto = tun.readBuff[:]
|
|
||||||
}
|
|
||||||
n, err := tun.tunFile.Read(readInto)
|
|
||||||
if errors.Is(err, syscall.EBADFD) {
|
|
||||||
err = os.ErrClosed
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if tun.vnetHdr {
|
|
||||||
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user