Mirror of https://github.com/slackhq/nebula.git
hmmmmmm it works i guess maybe
batch_pipeline.go (new file, 164 lines)
@@ -0,0 +1,164 @@
package nebula

import (
	"net/netip"

	"github.com/slackhq/nebula/overlay"
	"github.com/slackhq/nebula/udp"
)

// batchPipelines tracks whether the inside device can operate on packet batches
// and, if so, holds the shared packet pool sized for the virtio headroom and
// payload limits advertised by the device. It also owns the fan-in/fan-out
// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
type batchPipelines struct {
	enabled    bool
	inside     overlay.BatchCapableDevice
	headroom   int
	payloadCap int
	pool       *overlay.PacketPool
	batchSize  int
	routines   int
	rxQueues   []chan *overlay.Packet
	txQueues   []chan queuedDatagram
	tunQueues  []chan *overlay.Packet
}

type queuedDatagram struct {
	packet *overlay.Packet
	addr   netip.AddrPort
}

func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
	if device == nil || routines <= 0 {
		return
	}
	bcap, ok := device.(overlay.BatchCapableDevice)
	if !ok {
		return
	}
	headroom := bcap.BatchHeadroom()
	payload := bcap.BatchPayloadCap()
	if maxSegments < 1 {
		maxSegments = 1
	}
	requiredPayload := udp.MTU * maxSegments
	if payload < requiredPayload {
		payload = requiredPayload
	}
	batchSize := bcap.BatchSize()
	if headroom <= 0 || payload <= 0 || batchSize <= 0 {
		return
	}
	bp.enabled = true
	bp.inside = bcap
	bp.headroom = headroom
	bp.payloadCap = payload
	bp.batchSize = batchSize
	bp.routines = routines
	bp.pool = overlay.NewPacketPool(headroom, payload)
	queueCap := batchSize * defaultBatchQueueDepthFactor
	if queueDepth > 0 {
		queueCap = queueDepth
	}
	if queueCap < batchSize {
		queueCap = batchSize
	}
	bp.rxQueues = make([]chan *overlay.Packet, routines)
	bp.txQueues = make([]chan queuedDatagram, routines)
	bp.tunQueues = make([]chan *overlay.Packet, routines)
	for i := 0; i < routines; i++ {
		bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
		bp.txQueues[i] = make(chan queuedDatagram, queueCap)
		bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
	}
}

func (bp *batchPipelines) Pool() *overlay.PacketPool {
	if bp == nil || !bp.enabled {
		return nil
	}
	return bp.pool
}

func (bp *batchPipelines) Enabled() bool {
	return bp != nil && bp.enabled
}

func (bp *batchPipelines) batchSizeHint() int {
	if bp == nil || bp.batchSize <= 0 {
		return 1
	}
	return bp.batchSize
}

func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
		return nil
	}
	return bp.rxQueues[i]
}

func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
		return nil
	}
	return bp.txQueues[i]
}

func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
		return nil
	}
	return bp.tunQueues[i]
}

func (bp *batchPipelines) txQueueLen(i int) int {
	q := bp.txQueue(i)
	if q == nil {
		return 0
	}
	return len(q)
}

func (bp *batchPipelines) tunQueueLen(i int) int {
	q := bp.tunQueue(i)
	if q == nil {
		return 0
	}
	return len(q)
}

func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
	q := bp.rxQueue(i)
	if q == nil {
		return false
	}
	q <- pkt
	return true
}

func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
	q := bp.txQueue(i)
	if q == nil {
		return false
	}
	q <- queuedDatagram{packet: pkt, addr: addr}
	return true
}

func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
	q := bp.tunQueue(i)
	if q == nil {
		return false
	}
	q <- pkt
	return true
}

func (bp *batchPipelines) newPacket() *overlay.Packet {
	if bp == nil || !bp.enabled || bp.pool == nil {
		return nil
	}
	return bp.pool.Get()
}
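The pipeline above is wired up by the interface at startup: readers fan packets into per-routine bounded queues and workers drain them, releasing each buffer back to the pool. Below is a minimal, self-contained sketch of that same fan-in/fan-out shape; the worker count, queue depth, and the use of a plain sync.Pool are illustrative assumptions, not the exact values or types Nebula uses.

	package main

	import (
		"fmt"
		"sync"
	)

	// packet stands in for overlay.Packet: a reusable buffer with a release hook.
	type packet struct {
		buf     []byte
		release func()
	}

	func main() {
		const routines = 2 // illustrative worker count
		const queueCap = 8 // illustrative per-routine queue depth

		// A sync.Pool plays the role of overlay.PacketPool here.
		pool := sync.Pool{New: func() any { return &packet{buf: make([]byte, 2048)} }}

		// One bounded rx queue per routine, as in batchPipelines.rxQueues.
		rxQueues := make([]chan *packet, routines)
		for i := range rxQueues {
			rxQueues[i] = make(chan *packet, queueCap)
		}

		var wg sync.WaitGroup
		// Workers drain their own queue and release buffers when done,
		// mirroring runInsideBatchWorker.
		for i := 0; i < routines; i++ {
			wg.Add(1)
			go func(q chan *packet) {
				defer wg.Done()
				for pkt := range q {
					fmt.Printf("processed %d bytes\n", len(pkt.buf))
					pkt.release()
				}
			}(rxQueues[i])
		}

		// A "reader" fans packets into the queues, as listenInBatchLocked does.
		for i := 0; i < 4; i++ {
			pkt := pool.Get().(*packet)
			pkt.release = func() { pool.Put(pkt) }
			rxQueues[i%routines] <- pkt
		}
		for _, q := range rxQueues {
			close(q)
		}
		wg.Wait()
	}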
inside.go (66 changed lines)
@@ -8,6 +8,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
+	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/routing"
 )
@@ -335,9 +336,21 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	if ci.eKey == nil {
 		return
 	}
-	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
+	target := remote
+	if !target.IsValid() {
+		target = hostinfo.remote
+	}
+	useRelay := !target.IsValid()
 	fullOut := out
+
+	var pkt *overlay.Packet
+	if !useRelay && f.batches.Enabled() {
+		pkt = f.batches.newPacket()
+		if pkt != nil {
+			out = pkt.Payload()[:0]
+		}
+	}
+
 	if useRelay {
 		if len(out) < header.Len {
 			// out always has a capacity of mtu, but not always a length greater than the header.Len.
@@ -376,26 +389,52 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		ci.writeLock.Unlock()
 	}
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).
-			WithField("udpAddr", remote).WithField("counter", c).
+			WithField("udpAddr", target).WithField("counter", c).
 			WithField("attemptedCounter", c).
 			Error("Failed to encrypt outgoing packet")
 		return
 	}

-	if remote.IsValid() {
-		err = f.writers[q].WriteTo(out, remote)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
-		}
-	} else if hostinfo.remote.IsValid() {
-		err = f.writers[q].WriteTo(out, hostinfo.remote)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
-		}
-	} else {
+	if target.IsValid() {
+		if pkt != nil {
+			pkt.Len = len(out)
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithFields(logrus.Fields{
+					"queue":        q,
+					"dest":         target,
+					"payload_len":  pkt.Len,
+					"use_batches":  true,
+					"remote_index": hostinfo.remoteIndexId,
+				}).Debug("enqueueing packet to UDP batch queue")
+			}
+			if f.tryQueuePacket(q, pkt, target) {
+				return
+			}
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithFields(logrus.Fields{
+					"queue": q,
+					"dest":  target,
+				}).Debug("failed to enqueue packet; falling back to immediate send")
+			}
+			f.writeImmediatePacket(q, pkt, target, hostinfo)
+			return
+		}
+		if f.tryQueueDatagram(q, out, target) {
+			return
+		}
+		f.writeImmediate(q, out, target, hostinfo)
+		return
+	}
+
+	// fall back to relay path
+	if pkt != nil {
+		pkt.Release()
+	}
+
 	// Try to send via a relay
 	for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
 		relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
@@ -407,5 +446,4 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 			f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
 			break
 		}
-	}
 }
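The send path above prefers handing the encrypted packet to a per-writer batch queue and only falls back to a synchronous write when queueing is not possible. A minimal standalone sketch of that "try to queue, otherwise send immediately" pattern is below; the names and the select/default non-blocking send are illustrative assumptions (the real tryQueuePacket/writeImmediatePacket helpers live in interface.go and use the pipeline's bounded queues).

	package main

	import "fmt"

	type datagram struct {
		payload []byte
		dest    string // stands in for netip.AddrPort
	}

	// tryQueue attempts a non-blocking hand-off to the writer goroutine.
	func tryQueue(q chan datagram, d datagram) bool {
		select {
		case q <- d:
			return true
		default:
			return false // queue full or missing: caller sends immediately
		}
	}

	func writeImmediate(d datagram) {
		fmt.Printf("direct write of %d bytes to %s\n", len(d.payload), d.dest)
	}

	func main() {
		q := make(chan datagram, 1)
		for i := 0; i < 3; i++ {
			d := datagram{payload: []byte("hello"), dest: "192.0.2.1:4242"}
			if !tryQueue(q, d) {
				writeImmediate(d) // fallback path, like writeImmediatePacket
			}
		}
		fmt.Println("queued:", len(q))
	}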
interface.go (603 changed lines)
@@ -21,7 +21,13 @@ import (
 	"github.com/slackhq/nebula/udp"
 )

-const mtu = 9001
+const (
+	mtu                          = 9001
+	defaultGSOFlushInterval      = 150 * time.Microsecond
+	defaultBatchQueueDepthFactor = 4
+	defaultGSOMaxSegments        = 8
+	maxKernelGSOSegments         = 64
+)
+
 type InterfaceConfig struct {
 	HostMap *HostMap
@@ -36,6 +42,9 @@ type InterfaceConfig struct {
 	connectionManager  *connectionManager
 	DropLocalBroadcast bool
 	DropMulticast      bool
+	EnableGSO          bool
+	EnableGRO          bool
+	GSOMaxSegments     int
 	routines           int
 	MessageMetrics     *MessageMetrics
 	version            string
@@ -47,6 +56,8 @@ type InterfaceConfig struct {
 	reQueryWait time.Duration

 	ConntrackCacheTimeout time.Duration
+	BatchFlushInterval    time.Duration
+	BatchQueueDepth       int
 	l *logrus.Logger
 }
@@ -84,9 +95,20 @@ type Interface struct {
 	version string

 	conntrackCacheTimeout time.Duration
+	batchQueueDepth       int
+	enableGSO             bool
+	enableGRO             bool
+	gsoMaxSegments        int
+	batchUDPQueueGauge    metrics.Gauge
+	batchUDPFlushCounter  metrics.Counter
+	batchTunQueueGauge    metrics.Gauge
+	batchTunFlushCounter  metrics.Counter
+	batchFlushInterval    atomic.Int64
+	sendSem               chan struct{}

 	writers []udp.Conn
 	readers []io.ReadWriteCloser
+	batches batchPipelines

 	metricHandshakes metrics.Histogram
 	messageMetrics   *MessageMetrics
@@ -161,6 +183,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no connection manager")
 	}

+	if c.GSOMaxSegments <= 0 {
+		c.GSOMaxSegments = defaultGSOMaxSegments
+	}
+	if c.GSOMaxSegments > maxKernelGSOSegments {
+		c.GSOMaxSegments = maxKernelGSOSegments
+	}
+	if c.BatchQueueDepth <= 0 {
+		c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
+	}
+	if c.BatchFlushInterval < 0 {
+		c.BatchFlushInterval = 0
+	}
+	if c.BatchFlushInterval == 0 && c.EnableGSO {
+		c.BatchFlushInterval = defaultGSOFlushInterval
+	}
+
 	cs := c.pki.getCertState()
 	ifce := &Interface{
 		pki: c.pki,
@@ -186,6 +224,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		relayManager:          c.relayManager,
 		connectionManager:     c.connectionManager,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
+		batchQueueDepth:       c.BatchQueueDepth,
+		enableGSO:             c.EnableGSO,
+		enableGRO:             c.EnableGRO,
+		gsoMaxSegments:        c.GSOMaxSegments,

 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
@@ -198,8 +240,25 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	}

 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
+	ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
+	ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
+	ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
+	ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
+	ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
+	ifce.sendSem = make(chan struct{}, c.routines)
+	ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
+	if c.l.Level >= logrus.DebugLevel {
+		c.l.WithFields(logrus.Fields{
+			"enableGSO":       c.EnableGSO,
+			"enableGRO":       c.EnableGRO,
+			"gsoMaxSegments":  c.GSOMaxSegments,
+			"batchQueueDepth": c.BatchQueueDepth,
+			"batchFlush":      c.BatchFlushInterval,
+			"batching":        ifce.batches.Enabled(),
+		}).Debug("initialized batch pipelines")
+	}

 	ifce.connectionManager.intf = ifce

@@ -248,6 +307,18 @@ func (f *Interface) run() {
 		go f.listenOut(i)
 	}

+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
+	}
+
+	if f.batches.Enabled() {
+		for i := 0; i < f.routines; i++ {
+			go f.runInsideBatchWorker(i)
+			go f.runTunWriteQueue(i)
+			go f.runSendQueue(i)
+		}
+	}
+
 	// Launch n queues to read packets from tun dev
 	for i := 0; i < f.routines; i++ {
 		go f.listenIn(f.readers[i], i)
@@ -279,6 +350,17 @@ func (f *Interface) listenOut(i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	runtime.LockOSThread()

+	if f.batches.Enabled() {
+		if br, ok := reader.(overlay.BatchReader); ok {
+			f.listenInBatchLocked(reader, br, i)
+			return
+		}
+	}
+
+	f.listenInLegacyLocked(reader, i)
+}
+
+func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
 	packet := make([]byte, mtu)
 	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}
@@ -302,6 +384,489 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
	}
}

func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
	pool := f.batches.Pool()
	if pool == nil {
		f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
		f.listenInLegacyLocked(raw, i)
		return
	}

	for {
		packets, err := reader.ReadIntoBatch(pool)
		if err != nil {
			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
				return
			}

			f.l.WithError(err).Error("Error while reading outbound packet batch")
			os.Exit(2)
		}

		if len(packets) == 0 {
			continue
		}

		for _, pkt := range packets {
			if pkt == nil {
				continue
			}
			if !f.batches.enqueueRx(i, pkt) {
				pkt.Release()
			}
		}
	}
}

func (f *Interface) runInsideBatchWorker(i int) {
	queue := f.batches.rxQueue(i)
	if queue == nil {
		return
	}

	out := make([]byte, mtu)
	fwPacket := &firewall.Packet{}
	nb := make([]byte, 12, 12)
	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)

	for pkt := range queue {
		if pkt == nil {
			continue
		}
		f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
		pkt.Release()
	}
}

func (f *Interface) runSendQueue(i int) {
	queue := f.batches.txQueue(i)
	if queue == nil {
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
		}
		return
	}
	writer := f.writerForIndex(i)
	if writer == nil {
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
		}
		return
	}
	if f.l.Level >= logrus.DebugLevel {
		f.l.WithField("queue", i).Debug("send queue worker started")
	}
	defer func() {
		if f.l.Level >= logrus.WarnLevel {
			f.l.WithField("queue", i).Warn("send queue worker exited")
		}
	}()

	batchCap := f.batches.batchSizeHint()
	if batchCap <= 0 {
		batchCap = 1
	}
	gsoLimit := f.effectiveGSOMaxSegments()
	if gsoLimit > batchCap {
		batchCap = gsoLimit
	}
	pending := make([]queuedDatagram, 0, batchCap)
	var (
		flushTimer *time.Timer
		flushC     <-chan time.Time
	)
	dispatch := func(reason string, timerFired bool) {
		if len(pending) == 0 {
			return
		}
		batch := pending
		f.flushAndReleaseBatch(i, writer, batch, reason)
		for idx := range batch {
			batch[idx] = queuedDatagram{}
		}
		pending = pending[:0]
		if flushTimer != nil {
			if !timerFired {
				if !flushTimer.Stop() {
					select {
					case <-flushTimer.C:
					default:
					}
				}
			}
			flushTimer = nil
			flushC = nil
		}
	}
	armTimer := func() {
		delay := f.currentBatchFlushInterval()
		if delay <= 0 {
			dispatch("nogso", false)
			return
		}
		if flushTimer == nil {
			flushTimer = time.NewTimer(delay)
			flushC = flushTimer.C
		}
	}

	for {
		select {
		case d := <-queue:
			if d.packet == nil {
				continue
			}
			if f.l.Level >= logrus.DebugLevel {
				f.l.WithFields(logrus.Fields{
					"queue":       i,
					"payload_len": d.packet.Len,
					"dest":        d.addr,
				}).Debug("send queue received packet")
			}
			pending = append(pending, d)
			if gsoLimit > 0 && len(pending) >= gsoLimit {
				dispatch("gso", false)
				continue
			}
			if len(pending) >= cap(pending) {
				dispatch("cap", false)
				continue
			}
			armTimer()
			f.observeUDPQueueLen(i)
		case <-flushC:
			dispatch("timer", true)
		}
	}
}

func (f *Interface) runTunWriteQueue(i int) {
	queue := f.batches.tunQueue(i)
	if queue == nil {
		return
	}
	writer := f.batches.inside
	if writer == nil {
		return
	}

	batchCap := f.batches.batchSizeHint()
	if batchCap <= 0 {
		batchCap = 1
	}
	pending := make([]*overlay.Packet, 0, batchCap)
	var (
		flushTimer *time.Timer
		flushC     <-chan time.Time
	)
	flush := func(reason string, timerFired bool) {
		if len(pending) == 0 {
			return
		}
		if _, err := writer.WriteBatch(pending); err != nil {
			f.l.WithError(err).
				WithField("queue", i).
				WithField("reason", reason).
				Warn("Failed to write tun batch")
		}
		for idx := range pending {
			if pending[idx] != nil {
				pending[idx].Release()
			}
		}
		pending = pending[:0]
		if flushTimer != nil {
			if !timerFired {
				if !flushTimer.Stop() {
					select {
					case <-flushTimer.C:
					default:
					}
				}
			}
			flushTimer = nil
			flushC = nil
		}
	}
	armTimer := func() {
		delay := f.currentBatchFlushInterval()
		if delay <= 0 {
			return
		}
		if flushTimer == nil {
			flushTimer = time.NewTimer(delay)
			flushC = flushTimer.C
		}
	}

	for {
		select {
		case pkt := <-queue:
			if pkt == nil {
				continue
			}
			pending = append(pending, pkt)
			if len(pending) >= cap(pending) {
				flush("cap", false)
				continue
			}
			armTimer()
			f.observeTunQueueLen(i)
		case <-flushC:
			flush("timer", true)
		}
	}
}

func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
	if len(batch) == 0 {
		return
	}
	f.flushDatagrams(index, writer, batch, reason)
	for idx := range batch {
		if batch[idx].packet != nil {
			batch[idx].packet.Release()
			batch[idx].packet = nil
		}
	}
	if f.batchUDPFlushCounter != nil {
		f.batchUDPFlushCounter.Inc(int64(len(batch)))
	}
}

func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
	if len(batch) == 0 {
		return
	}
	if f.l.Level >= logrus.DebugLevel {
		f.l.WithFields(logrus.Fields{
			"writer":  index,
			"reason":  reason,
			"pending": len(batch),
		}).Debug("udp batch flush summary")
	}
	maxSeg := f.effectiveGSOMaxSegments()
	if bw, ok := writer.(udp.BatchConn); ok {
		chunkCap := maxSeg
		if chunkCap <= 0 {
			chunkCap = len(batch)
		}
		chunk := make([]udp.Datagram, 0, chunkCap)
		var (
			currentAddr netip.AddrPort
			segments    int
		)
		flushChunk := func() {
			if len(chunk) == 0 {
				return
			}
			if f.l.Level >= logrus.DebugLevel {
				f.l.WithFields(logrus.Fields{
					"writer":        index,
					"segments":      len(chunk),
					"dest":          chunk[0].Addr,
					"reason":        reason,
					"pending_total": len(batch),
				}).Debug("flushing UDP batch")
			}
			if err := bw.WriteBatch(chunk); err != nil {
				f.l.WithError(err).
					WithField("writer", index).
					WithField("reason", reason).
					Warn("Failed to write UDP batch")
			}
			chunk = chunk[:0]
			segments = 0
		}
		for _, item := range batch {
			if item.packet == nil || !item.addr.IsValid() {
				continue
			}
			payload := item.packet.Payload()[:item.packet.Len]
			if segments == 0 {
				currentAddr = item.addr
			}
			if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) {
				flushChunk()
				currentAddr = item.addr
			}
			chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr})
			segments++
		}
		flushChunk()
		return
	}
	for _, item := range batch {
		if item.packet == nil || !item.addr.IsValid() {
			continue
		}
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithFields(logrus.Fields{
				"writer":   index,
				"reason":   reason,
				"dest":     item.addr,
				"segments": 1,
			}).Debug("flushing UDP batch")
		}
		if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil {
			f.l.WithError(err).
				WithField("writer", index).
				WithField("udpAddr", item.addr).
				WithField("reason", reason).
				Warn("Failed to write UDP packet")
		}
	}
}

func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool {
	if !addr.IsValid() || !f.batches.Enabled() {
		return false
	}
	pkt := f.batches.newPacket()
	if pkt == nil {
		return false
	}
	payload := pkt.Payload()
	if len(payload) < len(buf) {
		pkt.Release()
		return false
	}
	copy(payload, buf)
	pkt.Len = len(buf)
	if f.batches.enqueueTx(q, pkt, addr) {
		f.observeUDPQueueLen(q)
		return true
	}
	pkt.Release()
	return false
}

func (f *Interface) writerForIndex(i int) udp.Conn {
	if i < 0 || i >= len(f.writers) {
		return nil
	}
	return f.writers[i]
}

func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) {
	writer := f.writerForIndex(q)
	if writer == nil {
		f.l.WithField("udpAddr", addr).
			WithField("writer", q).
			Error("Failed to write outgoing packet: no writer available")
		return
	}
	if err := writer.WriteTo(buf, addr); err != nil {
		hostinfo.logger(f.l).
			WithError(err).
			WithField("udpAddr", addr).
			Error("Failed to write outgoing packet")
	}
}

func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool {
	if pkt == nil || !addr.IsValid() || !f.batches.Enabled() {
		return false
	}
	if f.batches.enqueueTx(q, pkt, addr) {
		f.observeUDPQueueLen(q)
		return true
	}
	return false
}

func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) {
	if pkt == nil {
		return
	}
	writer := f.writerForIndex(q)
	if writer == nil {
		f.l.WithField("udpAddr", addr).
			WithField("writer", q).
			Error("Failed to write outgoing packet: no writer available")
		pkt.Release()
		return
	}
	if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil {
		hostinfo.logger(f.l).
			WithError(err).
			WithField("udpAddr", addr).
			Error("Failed to write outgoing packet")
	}
	pkt.Release()
}

func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) {
	if pkt == nil {
		return
	}
	writer := f.readers[q]
	if writer == nil {
		pkt.Release()
		return
	}
	if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil {
		f.l.WithError(err).Error("Failed to write to tun")
	}
	pkt.Release()
}

func (f *Interface) observeUDPQueueLen(i int) {
	if f.batchUDPQueueGauge == nil {
		return
	}
	f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i)))
}

func (f *Interface) observeTunQueueLen(i int) {
	if f.batchTunQueueGauge == nil {
		return
	}
	f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i)))
}

func (f *Interface) currentBatchFlushInterval() time.Duration {
	if v := f.batchFlushInterval.Load(); v > 0 {
		return time.Duration(v)
	}
	return 0
}

func (f *Interface) effectiveGSOMaxSegments() int {
	max := f.gsoMaxSegments
	if max <= 0 {
		max = defaultGSOMaxSegments
	}
	if max > maxKernelGSOSegments {
		max = maxKernelGSOSegments
	}
	if !f.enableGSO {
		return 1
	}
	return max
}

type udpOffloadConfigurator interface {
	ConfigureOffload(enableGSO, enableGRO bool, maxSegments int)
}

func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) {
	if maxSegments <= 0 {
		maxSegments = defaultGSOMaxSegments
	}
	if maxSegments > maxKernelGSOSegments {
		maxSegments = maxKernelGSOSegments
	}
	f.enableGSO = enableGSO
	f.enableGRO = enableGRO
	f.gsoMaxSegments = maxSegments
	for _, writer := range f.writers {
		if cfg, ok := writer.(udpOffloadConfigurator); ok {
			cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments)
		}
	}
}

 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
@@ -404,6 +969,42 @@ func (f *Interface) reloadMisc(c *config.C) {
 		f.reQueryWait.Store(int64(n))
 		f.l.Info("timers.requery_wait_duration has changed")
 	}
+
+	if c.HasChanged("listen.gso_flush_timeout") {
+		d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
+		if d < 0 {
+			d = 0
+		}
+		f.batchFlushInterval.Store(int64(d))
+		f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed")
+	} else if c.HasChanged("batch.flush_interval") {
+		d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval)
+		if d < 0 {
+			d = 0
+		}
+		f.batchFlushInterval.Store(int64(d))
+		f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout")
+	}
+
+	if c.HasChanged("batch.queue_depth") {
+		n := c.GetInt("batch.queue_depth", f.batchQueueDepth)
+		if n != f.batchQueueDepth {
+			f.batchQueueDepth = n
+			f.l.Warn("batch.queue_depth changes require a restart to take effect")
+		}
+	}
+
+	if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") {
+		enableGSO := c.GetBool("listen.enable_gso", f.enableGSO)
+		enableGRO := c.GetBool("listen.enable_gro", f.enableGRO)
+		maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments)
+		f.applyOffloadConfig(enableGSO, enableGRO, maxSeg)
+		f.l.WithFields(logrus.Fields{
+			"enableGSO":      enableGSO,
+			"enableGRO":      enableGRO,
+			"gsoMaxSegments": maxSeg,
+		}).Info("listen GSO/GRO configuration updated")
+	}
 }

 func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
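runSendQueue and runTunWriteQueue above share the same accumulate-then-flush shape: append to a pending slice, flush when it reaches a size limit, and otherwise arm a one-shot timer so a partially filled batch still goes out within the configured interval. The standalone sketch below shows that shape with a generic item type; the size and delay constants are illustrative assumptions, not Nebula's defaults.

	package main

	import (
		"fmt"
		"time"
	)

	func main() {
		const maxBatch = 4                  // illustrative size limit
		const flushDelay = time.Millisecond // illustrative timer interval

		in := make(chan int, 16)
		go func() {
			for i := 0; i < 10; i++ {
				in <- i
				time.Sleep(300 * time.Microsecond)
			}
			close(in)
		}()

		pending := make([]int, 0, maxBatch)
		var timer *time.Timer
		var timerC <-chan time.Time

		flush := func(reason string) {
			if len(pending) == 0 {
				return
			}
			fmt.Printf("flush(%s): %v\n", reason, pending)
			pending = pending[:0]
			if timer != nil {
				timer.Stop()
				timer, timerC = nil, nil
			}
		}

		for {
			select {
			case v, ok := <-in:
				if !ok {
					flush("shutdown")
					return
				}
				pending = append(pending, v)
				if len(pending) >= maxBatch {
					flush("cap") // size limit reached, mirrors dispatch("cap", ...)
					continue
				}
				if timer == nil {
					timer = time.NewTimer(flushDelay) // arm one-shot timer, mirrors armTimer
					timerC = timer.C
				}
			case <-timerC:
				flush("timer")
			}
		}
	}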
main.go (25 changed lines)
@@ -144,6 +144,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	// set up our UDP listener
 	udpConns := make([]udp.Conn, routines)
 	port := c.GetInt("listen.port", 0)
+	enableGSO := c.GetBool("listen.enable_gso", true)
+	enableGRO := c.GetBool("listen.enable_gro", true)
+	gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
+	if gsoMaxSegments <= 0 {
+		gsoMaxSegments = defaultGSOMaxSegments
+	}
+	if gsoMaxSegments > maxKernelGSOSegments {
+		gsoMaxSegments = maxKernelGSOSegments
+	}
+	gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
+	if gsoFlushTimeout < 0 {
+		gsoFlushTimeout = 0
+	}
+	batchQueueDepth := c.GetInt("batch.queue_depth", 0)
+
 	if !configTest {
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
@@ -179,6 +193,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 		}
 		udpServer.ReloadConfig(c)
+		if cfg, ok := udpServer.(interface {
+			ConfigureOffload(bool, bool, int)
+		}); ok {
+			cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
+		}
 		udpConns[i] = udpServer

 		// If port is dynamic, discover it before the next pass through the for loop
@@ -246,12 +265,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		reQueryWait:           c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
 		DropLocalBroadcast:    c.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:         c.GetBool("tun.drop_multicast", false),
+		EnableGSO:             enableGSO,
+		EnableGRO:             enableGRO,
+		GSOMaxSegments:        gsoMaxSegments,
 		routines:              routines,
 		MessageMetrics:        messageMetrics,
 		version:               buildVersion,
 		relayManager:          NewRelayManager(ctx, l, hostMap, c),
 		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
+		BatchFlushInterval:    gsoFlushTimeout,
+		BatchQueueDepth:       batchQueueDepth,
 		l:                     l,
 	}
@@ -263,6 +287,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}

 	ifce.writers = udpConns
+	ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
 	lightHouse.ifce = ifce

 	ifce.RegisterConfigChangeCallbacks(c)
outside.go (38 changed lines)
@@ -12,6 +12,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/overlay"
 	"golang.org/x/net/ipv4"
 )
@@ -466,22 +467,41 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 }

 func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
-	var err error
+	var (
+		err error
+		pkt *overlay.Packet
+	)
+
+	if f.batches.tunQueue(q) != nil {
+		pkt = f.batches.newPacket()
+		if pkt != nil {
+			out = pkt.Payload()[:0]
+		}
+	}
+
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		return false
 	}

 	err = newPacket(out, true, fwPacket)
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")
 		return false
 	}

 	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 			Debugln("dropping out of window packet")
 		return false
@@ -489,6 +509,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out

 	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
 		// This gives us a buffer to build the reject packet in
 		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
@@ -501,8 +524,17 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	}

 	f.connectionManager.In(hostinfo)
-	_, err = f.readers[q].Write(out)
-	if err != nil {
+	if pkt != nil {
+		pkt.Len = len(out)
+		if f.batches.enqueueTun(q, pkt) {
+			f.observeTunQueueLen(q)
+			return true
+		}
+		f.writePacketToTun(q, pkt)
+		return true
+	}
+
+	if _, err = f.readers[q].Write(out); err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
 	}
 	return true
@@ -3,6 +3,7 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
@@ -15,3 +16,84 @@ type Device interface {
|
|||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Packet represents a single packet buffer with optional headroom to carry
|
||||||
|
// metadata (for example virtio-net headers).
|
||||||
|
type Packet struct {
|
||||||
|
Buf []byte
|
||||||
|
Offset int
|
||||||
|
Len int
|
||||||
|
release func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Payload() []byte {
|
||||||
|
return p.Buf[p.Offset : p.Offset+p.Len]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Reset() {
|
||||||
|
p.Len = 0
|
||||||
|
p.Offset = 0
|
||||||
|
p.release = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Release() {
|
||||||
|
if p.release != nil {
|
||||||
|
p.release()
|
||||||
|
p.release = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Capacity() int {
|
||||||
|
return len(p.Buf) - p.Offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// PacketPool manages reusable buffers with headroom.
|
||||||
|
type PacketPool struct {
|
||||||
|
headroom int
|
||||||
|
blksz int
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPacketPool(headroom, payload int) *PacketPool {
|
||||||
|
p := &PacketPool{headroom: headroom, blksz: headroom + payload}
|
||||||
|
p.pool.New = func() any {
|
||||||
|
buf := make([]byte, p.blksz)
|
||||||
|
return &Packet{Buf: buf, Offset: headroom}
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketPool) Get() *Packet {
|
||||||
|
pkt := p.pool.Get().(*Packet)
|
||||||
|
pkt.Offset = p.headroom
|
||||||
|
pkt.Len = 0
|
||||||
|
pkt.release = func() { p.put(pkt) }
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketPool) put(pkt *Packet) {
|
||||||
|
pkt.Reset()
|
||||||
|
p.pool.Put(pkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchReader allows reading multiple packets into a shared pool with
|
||||||
|
// preallocated headroom (e.g. virtio-net headers).
|
||||||
|
type BatchReader interface {
|
||||||
|
ReadIntoBatch(pool *PacketPool) ([]*Packet, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchWriter writes a slice of packets that carry their own metadata.
|
||||||
|
type BatchWriter interface {
|
||||||
|
WriteBatch(packets []*Packet) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchCapableDevice describes a device that can efficiently read and write
|
||||||
|
// batches of packets with virtio headroom.
|
||||||
|
type BatchCapableDevice interface {
|
||||||
|
Device
|
||||||
|
BatchReader
|
||||||
|
BatchWriter
|
||||||
|
BatchHeadroom() int
|
||||||
|
BatchPayloadCap() int
|
||||||
|
BatchSize() int
|
||||||
|
}
|
||||||
|
|||||||
56
overlay/tun_linux_batch.go
Normal file
56
overlay/tun_linux_batch.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build linux && !android && !e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func (t *tun) batchIO() (*wireguardTunIO, bool) {
|
||||||
|
io, ok := t.ReadWriteCloser.(*wireguardTunIO)
|
||||||
|
return io, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
||||||
|
io, ok := t.batchIO()
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("wireguard batch I/O not enabled")
|
||||||
|
}
|
||||||
|
return io.ReadIntoBatch(pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) WriteBatch(packets []*Packet) (int, error) {
|
||||||
|
io, ok := t.batchIO()
|
||||||
|
if ok {
|
||||||
|
return io.WriteBatch(packets)
|
||||||
|
}
|
||||||
|
for _, pkt := range packets {
|
||||||
|
if pkt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
return len(packets), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchHeadroom() int {
|
||||||
|
if io, ok := t.batchIO(); ok {
|
||||||
|
return io.BatchHeadroom()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchPayloadCap() int {
|
||||||
|
if io, ok := t.batchIO(); ok {
|
||||||
|
return io.BatchPayloadCap()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) BatchSize() int {
|
||||||
|
if io, ok := t.batchIO(); ok {
|
||||||
|
return io.BatchSize()
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
@@ -15,14 +15,14 @@ type wireguardTunIO struct {
|
|||||||
batchSize int
|
batchSize int
|
||||||
|
|
||||||
readMu sync.Mutex
|
readMu sync.Mutex
|
||||||
readBufs [][]byte
|
readBuffers [][]byte
|
||||||
readLens []int
|
readLens []int
|
||||||
pending [][]byte
|
legacyBuf []byte
|
||||||
pendIdx int
|
|
||||||
|
|
||||||
writeMu sync.Mutex
|
writeMu sync.Mutex
|
||||||
writeBuf []byte
|
writeBuf []byte
|
||||||
writeWrap [][]byte
|
writeWrap [][]byte
|
||||||
|
writeBuffers [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
||||||
@@ -33,17 +33,12 @@ func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
|||||||
if mtu <= 0 {
|
if mtu <= 0 {
|
||||||
mtu = DefaultMTU
|
mtu = DefaultMTU
|
||||||
}
|
}
|
||||||
bufs := make([][]byte, batch)
|
|
||||||
for i := range bufs {
|
|
||||||
bufs[i] = make([]byte, wgtun.VirtioNetHdrLen+mtu)
|
|
||||||
}
|
|
||||||
return &wireguardTunIO{
|
return &wireguardTunIO{
|
||||||
dev: dev,
|
dev: dev,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
batchSize: batch,
|
batchSize: batch,
|
||||||
readBufs: bufs,
|
|
||||||
readLens: make([]int, batch),
|
readLens: make([]int, batch),
|
||||||
pending: make([][]byte, 0, batch),
|
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||||
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||||
writeWrap: make([][]byte, 1),
|
writeWrap: make([][]byte, 1),
|
||||||
}
|
}
|
||||||
@@ -53,29 +48,21 @@ func (w *wireguardTunIO) Read(p []byte) (int, error) {
|
|||||||
w.readMu.Lock()
|
w.readMu.Lock()
|
||||||
defer w.readMu.Unlock()
|
defer w.readMu.Unlock()
|
||||||
|
|
||||||
for {
|
bufs := w.readBuffers
|
||||||
if w.pendIdx < len(w.pending) {
|
if len(bufs) == 0 {
|
||||||
segment := w.pending[w.pendIdx]
|
bufs = [][]byte{w.legacyBuf}
|
||||||
w.pendIdx++
|
w.readBuffers = bufs
|
||||||
n := copy(p, segment)
|
|
||||||
return n, nil
|
|
||||||
}
|
}
|
||||||
|
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
|
||||||
n, err := w.dev.Read(w.readBufs, w.readLens, wgtun.VirtioNetHdrLen)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
w.pending = w.pending[:0]
|
if n == 0 {
|
||||||
w.pendIdx = 0
|
return 0, nil
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
length := w.readLens[i]
|
|
||||||
if length == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
segment := w.readBufs[i][wgtun.VirtioNetHdrLen : wgtun.VirtioNetHdrLen+length]
|
|
||||||
w.pending = append(w.pending, segment)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
length := w.readLens[0]
|
||||||
|
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
|
||||||
|
return length, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
||||||
@@ -97,6 +84,134 @@ func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
||||||
|
if pool == nil {
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.readMu.Lock()
|
||||||
|
defer w.readMu.Unlock()
|
||||||
|
|
||||||
|
if len(w.readBuffers) < w.batchSize {
|
||||||
|
w.readBuffers = make([][]byte, w.batchSize)
|
||||||
|
}
|
||||||
|
if len(w.readLens) < w.batchSize {
|
||||||
|
w.readLens = make([]int, w.batchSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
packets := make([]*Packet, w.batchSize)
|
||||||
|
requiredHeadroom := w.BatchHeadroom()
|
||||||
|
requiredPayload := w.BatchPayloadCap()
|
||||||
|
headroom := 0
|
||||||
|
for i := 0; i < w.batchSize; i++ {
|
||||||
|
pkt := pool.Get()
|
||||||
|
if pkt == nil {
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
|
||||||
|
}
|
||||||
|
if pkt.Capacity() < requiredPayload {
|
||||||
|
pkt.Release()
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
headroom = pkt.Offset
|
||||||
|
if headroom < requiredHeadroom {
|
||||||
|
pkt.Release()
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
|
||||||
|
}
|
||||||
|
} else if pkt.Offset != headroom {
|
||||||
|
pkt.Release()
|
||||||
|
releasePackets(packets[:i])
|
||||||
|
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
|
||||||
|
}
|
||||||
|
packets[i] = pkt
|
||||||
|
w.readBuffers[i] = pkt.Buf
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
|
||||||
|
if err != nil {
|
||||||
|
releasePackets(packets)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
releasePackets(packets)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
packets[i].Len = w.readLens[i]
|
||||||
|
}
|
||||||
|
for i := n; i < w.batchSize; i++ {
|
||||||
|
packets[i].Release()
|
||||||
|
packets[i] = nil
|
||||||
|
}
|
||||||
|
return packets[:n], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
|
||||||
|
if len(packets) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
requiredHeadroom := w.BatchHeadroom()
|
||||||
|
offset := packets[0].Offset
|
||||||
|
if offset < requiredHeadroom {
|
||||||
|
releasePackets(packets)
|
||||||
|
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
|
||||||
|
}
|
||||||
|
for _, pkt := range packets {
|
||||||
|
if pkt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pkt.Offset != offset {
|
||||||
|
releasePackets(packets)
|
||||||
|
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
|
||||||
|
}
|
||||||
|
limit := pkt.Offset + pkt.Len
|
||||||
|
if limit > len(pkt.Buf) {
|
||||||
|
releasePackets(packets)
|
||||||
|
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.writeMu.Lock()
|
||||||
|
defer w.writeMu.Unlock()
|
||||||
|
|
||||||
|
if len(w.writeBuffers) < len(packets) {
|
||||||
|
w.writeBuffers = make([][]byte, len(packets))
|
||||||
|
}
|
||||||
|
for i, pkt := range packets {
|
||||||
|
if pkt == nil {
|
||||||
|
w.writeBuffers[i] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
limit := pkt.Offset + pkt.Len
|
||||||
|
w.writeBuffers[i] = pkt.Buf[:limit]
|
||||||
|
}
|
||||||
|
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
|
||||||
|
releasePackets(packets)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) BatchHeadroom() int {
|
||||||
|
return wgtun.VirtioNetHdrLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) BatchPayloadCap() int {
|
||||||
|
return w.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wireguardTunIO) BatchSize() int {
|
||||||
|
return w.batchSize
|
||||||
|
}
|
||||||
|
|
||||||
func (w *wireguardTunIO) Close() error {
|
func (w *wireguardTunIO) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func releasePackets(pkts []*Packet) {
|
||||||
|
for _, pkt := range pkts {
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
12
udp/conn.go
12
udp/conn.go
@@ -22,6 +22,18 @@ type Conn interface {
|
|||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Datagram represents a UDP payload destined to a specific address.
|
||||||
|
type Datagram struct {
|
||||||
|
Payload []byte
|
||||||
|
Addr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchConn can send multiple datagrams in one syscall.
|
||||||
|
type BatchConn interface {
|
||||||
|
Conn
|
||||||
|
WriteBatch(pkts []Datagram) error
|
||||||
|
}
|
||||||
|
|
||||||
type NoopConn struct{}
|
type NoopConn struct{}
|
||||||
|
|
||||||
func (NoopConn) Rebind() error {
|
func (NoopConn) Rebind() error {
|
||||||
|
@@ -20,8 +20,12 @@ type WGConn struct {
	bind      *wgconn.StdNetBind
	recvers   []wgconn.ReceiveFunc
	batch     int
	reqBatch  int
	localIP   netip.Addr
	localPort uint16
	enableGSO bool
	enableGRO bool
	gsoMaxSeg int
	closed    atomic.Bool

	closeOnce sync.Once
@@ -34,7 +38,9 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool,
	if err != nil {
		return nil, err
	}
	if batch <= 0 {
		batch = bind.BatchSize()
	} else if batch > bind.BatchSize() {
		batch = bind.BatchSize()
	}
	return &WGConn{
@@ -42,6 +48,7 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool,
		bind:      bind,
		recvers:   recvers,
		batch:     batch,
		reqBatch:  batch,
		localIP:   ip,
		localPort: actualPort,
	}, nil
@@ -118,6 +125,92 @@ func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
	return c.bind.Send([][]byte{b}, ep)
}

func (c *WGConn) WriteBatch(datagrams []Datagram) error {
	if len(datagrams) == 0 {
		return nil
	}
	if c.closed.Load() {
		return net.ErrClosed
	}
	max := c.batch
	if max <= 0 {
		max = len(datagrams)
		if max == 0 {
			max = 1
		}
	}
	bufs := make([][]byte, 0, max)
	var (
		current  netip.AddrPort
		endpoint *wgconn.StdNetEndpoint
		haveAddr bool
	)
	flush := func() error {
		if len(bufs) == 0 || endpoint == nil {
			bufs = bufs[:0]
			return nil
		}
		err := c.bind.Send(bufs, endpoint)
		bufs = bufs[:0]
		return err
	}

	for _, d := range datagrams {
		if len(d.Payload) == 0 || !d.Addr.IsValid() {
			continue
		}
		if !haveAddr || d.Addr != current {
			if err := flush(); err != nil {
				return err
			}
			current = d.Addr
			endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
			haveAddr = true
		}
		bufs = append(bufs, d.Payload)
		if len(bufs) >= max {
			if err := flush(); err != nil {
				return err
			}
		}
	}
	return flush()
}

func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
	c.enableGSO = enableGSO
	c.enableGRO = enableGRO
	if maxSegments <= 0 {
		maxSegments = 1
	} else if maxSegments > wgconn.IdealBatchSize {
		maxSegments = wgconn.IdealBatchSize
	}
	c.gsoMaxSeg = maxSegments

	effectiveBatch := c.reqBatch
	if enableGSO && c.bind != nil {
		bindBatch := c.bind.BatchSize()
		if effectiveBatch < bindBatch {
			if c.l != nil {
				c.l.WithFields(logrus.Fields{
					"requested": c.reqBatch,
					"effective": bindBatch,
				}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
			}
			effectiveBatch = bindBatch
		}
	}
	c.batch = effectiveBatch

	if c.l != nil {
		c.l.WithFields(logrus.Fields{
			"enableGSO":      enableGSO,
			"enableGRO":      enableGRO,
			"gsoMaxSegments": maxSegments,
		}).Debug("configured wireguard UDP offload")
	}
}

func (c *WGConn) ReloadConfig(*config.C) {
	// WireGuard bind currently does not expose runtime configuration knobs.
}
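
WriteBatch flushes its pending buffers whenever the destination address changes, so each bind.Send only carries the current run of same-address datagrams. A sketch of a caller that groups by destination first (the grouping helper is hypothetical, not part of this change):

// Sketch: coalesce datagrams per destination so each flush inside WriteBatch
// can hand as many payloads as possible to a single bind.Send.
func writeGrouped(c *WGConn, pkts []Datagram) error {
	byAddr := make(map[netip.AddrPort][]Datagram, len(pkts))
	order := make([]netip.AddrPort, 0, len(pkts))
	for _, d := range pkts {
		if _, seen := byAddr[d.Addr]; !seen {
			order = append(order, d.Addr)
		}
		byAddr[d.Addr] = append(byAddr[d.Addr], d)
	}
	for _, addr := range order {
		if err := c.WriteBatch(byAddr[addr]); err != nil {
			return err
		}
	}
	return nil
}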
12
wgstack/conn/errors_default.go
Normal file
@@ -0,0 +1,12 @@
//go:build !linux

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

func errShouldDisableUDPGSO(err error) bool {
	return false
}
26
wgstack/conn/errors_linux.go
Normal file
@@ -0,0 +1,26 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"errors"
	"os"

	"golang.org/x/sys/unix"
)

func errShouldDisableUDPGSO(err error) bool {
	var serr *os.SyscallError
	if errors.As(err, &serr) {
		// EIO is returned by udp_send_skb() if the device driver does not have
		// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
		// See:
		// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
		// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
		return serr.Err == unix.EIO
	}
	return false
}
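
errShouldDisableUDPGSO classifies the EIO a segmented send can hit when the NIC lacks tx checksumming, and callers are expected to retry without GSO in that case. A sketch of that fallback, with the two send paths passed in as parameters because they are placeholders rather than functions from this change:

// Sketch: on a GSO-specific failure, turn the offload off and retry the same
// payloads unsegmented. sendSegmented/sendPlain are hypothetical callbacks.
func sendWithFallback(sendSegmented, sendPlain func([][]byte) error, bufs [][]byte, gsoEnabled *bool) error {
	if *gsoEnabled {
		err := sendSegmented(bufs)
		if err == nil {
			return nil
		}
		if !errShouldDisableUDPGSO(err) {
			return err
		}
		*gsoEnabled = false
	}
	return sendPlain(bufs)
}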
15
wgstack/conn/features_default.go
Normal file
@@ -0,0 +1,15 @@
//go:build !linux
// +build !linux

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import "net"

func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
	return
}
29
wgstack/conn/features_linux.go
Normal file
@@ -0,0 +1,29 @@
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"net"

	"golang.org/x/sys/unix"
)

func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
	rc, err := conn.SyscallConn()
	if err != nil {
		return
	}
	err = rc.Control(func(fd uintptr) {
		_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
		txOffload = errSyscall == nil
		opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
		rxOffload = errSyscall == nil && opt == 1
	})
	if err != nil {
		return false, false
	}
	return txOffload, rxOffload
}
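
supportsUDPOffload only needs a bound *net.UDPConn, so the capability can be probed before deciding whether to enable GSO/GRO. A minimal sketch (the helper name is illustrative, not part of this change):

// Sketch: bind an ephemeral socket and check which offloads the kernel accepts.
// txOK means UDP_SEGMENT is available for sends; rxOK means UDP_GRO is enabled.
func probeOffload() (txOK, rxOK bool, err error) {
	conn, err := net.ListenUDP("udp", &net.UDPAddr{})
	if err != nil {
		return false, false, err
	}
	defer conn.Close()
	txOK, rxOK = supportsUDPOffload(conn)
	return txOK, rxOK, nil
}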
21
wgstack/conn/gso_default.go
Normal file
@@ -0,0 +1,21 @@
//go:build !linux

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
	return 0, nil
}

// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}

// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0
65
wgstack/conn/gso_linux.go
Normal file
@@ -0,0 +1,65 @@
//go:build linux

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"fmt"
	"unsafe"

	"golang.org/x/sys/unix"
)

const (
	sizeOfGSOData = 2
)

// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
	var (
		hdr  unix.Cmsghdr
		data []byte
		rem  = control
		err  error
	)

	for len(rem) > unix.SizeofCmsghdr {
		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
		if err != nil {
			return 0, fmt.Errorf("error parsing socket control message: %w", err)
		}
		if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
			var gso uint16
			copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
			return int(gso), nil
		}
	}
	return 0, nil
}

// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
	existingLen := len(*control)
	avail := cap(*control) - existingLen
	space := unix.CmsgSpace(sizeOfGSOData)
	if avail < space {
		return
	}
	*control = (*control)[:cap(*control)]
	gsoControl := (*control)[existingLen:]
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
	hdr.Level = unix.SOL_UDP
	hdr.Type = unix.UDP_SEGMENT
	hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
	copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
	*control = (*control)[:existingLen+space]
}

// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
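
setGSOSize appends a single UDP_SEGMENT control message to an existing buffer, and gsoControlSize is the space one such message occupies. A sketch of preparing the control data for a segmented send (the helper name and the 1280-byte segment size are arbitrary illustrations):

// Sketch: build a control slice asking the kernel to split a coalesced payload
// into segSize-byte datagrams; the result is what a sendmsg-style call takes
// as its oob argument.
func buildGSOControl(segSize uint16) []byte {
	control := make([]byte, 0, gsoControlSize)
	setGSOSize(&control, segSize)
	return control
}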
42
wgstack/conn/sticky_default.go
Normal file
@@ -0,0 +1,42 @@
//go:build !linux || android

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import "net/netip"

func (e *StdNetEndpoint) SrcIP() netip.Addr {
	return netip.Addr{}
}

func (e *StdNetEndpoint) SrcIfidx() int32 {
	return 0
}

func (e *StdNetEndpoint) SrcToString() string {
	return ""
}

// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.

// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}

// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}

// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0

const StdNetSupportsStickySockets = false
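
stickyControlSize and gsoControlSize exist so receive paths can size pooled control buffers up front. One plausible use, shown only as an assumption about how a caller might combine them:

// Sketch: a control buffer with room for both sticky-socket PKTINFO data and a
// UDP_GRO size message; on non-Linux builds both sizes collapse to zero.
func newControlBuf() []byte {
	return make([]byte, stickyControlSize+gsoControlSize)
}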