//go:build !android && !e2e_testing
// +build !android,!e2e_testing

package udp

import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"sync"
	"sync/atomic"
	"syscall"
	"time"
	"unsafe"

	"github.com/rcrowley/go-metrics"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/header"
	"golang.org/x/sys/unix"
)

const (
	defaultGSOMaxSegments    = 8
	defaultGSOMaxBytes       = MTU * defaultGSOMaxSegments
	defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
	defaultGSOFlushTimeout   = 150 * time.Microsecond
	linuxMaxGSOBatchBytes    = 0xFFFF // Linux UDP GSO still limits the datagram payload to 64 KiB
	maxSendmmsgBatch         = 32
)

var (
	// Global mutex to serialize io_uring initialization across all sockets
	ioUringInitMu sync.Mutex
)

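// StdConn is the Linux UDP socket backend. On top of a single SOCK_DGRAM fd
// it layers optional UDP GSO/GRO batching, an optional io_uring send/receive
// path, and a set of send shards with their own coalescing buffers.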
type StdConn struct {
	sysFd int
	isV4  bool
	l     *logrus.Logger
	batch int

	enableGRO bool
	enableGSO bool

	controlLen atomic.Int32

	gsoMaxSegments  int
	gsoMaxBytes     int
	gsoFlushTimeout time.Duration

	groBufSize    atomic.Int64
	rxBufferPool  chan []byte
	gsoBufferPool sync.Pool

	gsoBatches           metrics.Counter
	gsoSegments          metrics.Counter
	gsoSingles           metrics.Counter
	groBatches           metrics.Counter
	groSegments          metrics.Counter
	gsoFallbacks         metrics.Counter
	gsoFallbackMu        sync.Mutex
	gsoFallbackReasons   map[string]*atomic.Int64
	gsoBatchTick         atomic.Int64
	gsoBatchSegmentsTick atomic.Int64
	gsoSingleTick        atomic.Int64
	groBatchTick         atomic.Int64
	groSegmentsTick      atomic.Int64

	ioState         atomic.Pointer[ioUringState]
	ioRecvState     atomic.Pointer[ioUringRecvState]
	ioActive        atomic.Bool
	ioRecvActive    atomic.Bool
	ioAttempted     atomic.Bool
	ioClosing       atomic.Bool
	ioUringHoldoff  atomic.Int64
	ioUringMaxBatch atomic.Int64

	sendShards   []*sendShard
	shardCounter atomic.Uint32
}

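// sendTask is one unit of work for a send shard: a payload destined for addr,
// optionally a GSO batch of `segments` segments of `segSize` bytes each.
// owned marks buffers borrowed from the GSO pool that must be released after send.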
type sendTask struct {
	buf      []byte
	addr     netip.AddrPort
	segSize  int
	segments int
	owned    bool
}

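// batchSendItem describes a single sendmsg entry within an io_uring batch
// submission; resultBytes and err are filled in from the completion.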
type batchSendItem struct {
	task        *sendTask
	addr        netip.AddrPort
	payload     []byte
	control     []byte
	msgFlags    uint32
	resultBytes int
	err         error
}

const sendShardQueueDepth = 128
const (
	ioUringDefaultMaxBatch      = 32
	ioUringMinMaxBatch          = 1
	ioUringMaxMaxBatch          = 4096
	ioUringDefaultHoldoff       = 25 * time.Microsecond
	ioUringMinHoldoff           = 0
	ioUringMaxHoldoff           = 500 * time.Millisecond
	ioUringHoldoffSpinThreshold = 50 * time.Microsecond
)

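// ioUringSendmsgBatch is a package-level indirection over ring submission,
// presumably so it can be stubbed in tests.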
var ioUringSendmsgBatch = func(state *ioUringState, entries []ioUringBatchEntry) error {
	return state.SendmsgBatch(entries)
}

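// sendShard owns one GSO coalescing buffer and one sender goroutine. The
// pending* fields and the mmsg scratch slices are guarded by mu.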
type sendShard struct {
	parent *StdConn

	mu sync.Mutex

	pendingBuf      []byte
	pendingSegments int
	pendingAddr     netip.AddrPort
	pendingSegSize  int
	flushTimer      *time.Timer
	controlBuf      []byte

	mmsgHeaders []linuxMmsgHdr
	mmsgIovecs  []unix.Iovec
	mmsgLengths []int

	outQueue   chan *sendTask
	workerDone sync.WaitGroup
}

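// clampIoUringBatchSize bounds a requested submit-batch size to
// [ioUringMinMaxBatch, ioUringMaxMaxBatch] and to the ring's entry count,
// substituting ioUringDefaultMaxBatch for out-of-range requests.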
func clampIoUringBatchSize(requested int, ringEntries uint32) int {
	if requested < ioUringMinMaxBatch {
		requested = ioUringDefaultMaxBatch
	}
	if requested > ioUringMaxMaxBatch {
		requested = ioUringMaxMaxBatch
	}
	if ringEntries > 0 && requested > int(ringEntries) {
		requested = int(ringEntries)
	}
	if requested < ioUringMinMaxBatch {
		requested = ioUringMinMaxBatch
	}
	return requested
}

func (s *sendShard) currentHoldoff() time.Duration {
	if s.parent == nil {
		return 0
	}
	holdoff := s.parent.ioUringHoldoff.Load()
	if holdoff <= 0 {
		return 0
	}
	return time.Duration(holdoff)
}

func (s *sendShard) currentMaxBatch() int {
	if s == nil || s.parent == nil {
		return ioUringDefaultMaxBatch
	}
	maxBatch := s.parent.ioUringMaxBatch.Load()
	if maxBatch <= 0 {
		return ioUringDefaultMaxBatch
	}
	if maxBatch > ioUringMaxMaxBatch {
		maxBatch = ioUringMaxMaxBatch
	}
	return int(maxBatch)
}

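// initSendShards sizes the shard set to GOMAXPROCS so that concurrent
// writers spread across independent coalescing buffers and sender goroutines.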
func (u *StdConn) initSendShards() {
	shardCount := runtime.GOMAXPROCS(0)
	if shardCount < 1 {
		shardCount = 1
	}
	u.resizeSendShards(shardCount)
}

func toIPv4Mapped(v4 [4]byte) [16]byte {
	return [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, v4[0], v4[1], v4[2], v4[3]}
}

func (u *StdConn) populateSockaddrInet6(sa6 *unix.RawSockaddrInet6, addr netip.Addr) {
	sa6.Family = unix.AF_INET6
	if addr.Is4() {
		// Convert IPv4 to IPv4-mapped IPv6 format for dual-stack socket
		sa6.Addr = toIPv4Mapped(addr.As4())
	} else {
		sa6.Addr = addr.As16()
	}
	sa6.Scope_id = 0
}

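// selectSendShard picks a shard round-robin via an atomic counter.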
func (u *StdConn) selectSendShard(addr netip.AddrPort) *sendShard {
	if len(u.sendShards) == 0 {
		return nil
	}
	if len(u.sendShards) == 1 {
		return u.sendShards[0]
	}
	idx := int(u.shardCounter.Add(1)-1) % len(u.sendShards)
	if idx < 0 {
		idx = -idx
	}
	return u.sendShards[idx]
}

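// resizeSendShards flushes and stops any existing shards, then replaces them
// with a fresh set of count sender goroutines.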
func (u *StdConn) resizeSendShards(count int) {
	if count <= 0 {
		count = runtime.GOMAXPROCS(0)
		if count < 1 {
			count = 1
		}
	}

	if len(u.sendShards) == count {
		return
	}

	// Give existing shard workers time to fully initialize before stopping.
	// This prevents a race where we try to stop shards before they're ready.
	if len(u.sendShards) > 0 {
		time.Sleep(time.Millisecond)
	}

	for _, shard := range u.sendShards {
		if shard == nil {
			continue
		}
		shard.mu.Lock()
		if shard.pendingSegments > 0 {
			if err := shard.flushPendingLocked(); err != nil {
				u.l.WithError(err).Warn("Failed to flush send shard while resizing")
			}
		} else {
			shard.stopFlushTimerLocked()
		}
		buf := shard.pendingBuf
		shard.pendingBuf = nil
		shard.mu.Unlock()
		if buf != nil {
			u.releaseGSOBuf(buf)
		}
		shard.stopSender()
	}

	newShards := make([]*sendShard, count)
	for i := range newShards {
		shard := &sendShard{parent: u}
		shard.startSender()
		newShards[i] = shard
	}
	u.sendShards = newShards
	u.shardCounter.Store(0)
	u.l.WithField("send_shards", count).Debug("Configured UDP send shards")
}

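// setGroBufferSize raises the GRO read buffer size and, on first use, builds
// the receive buffer pool with batch*4 preallocated buffers.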
func (u *StdConn) setGroBufferSize(size int) {
	if size < defaultGROReadBufferSize {
		size = defaultGROReadBufferSize
	}
	u.groBufSize.Store(int64(size))
	if u.rxBufferPool == nil {
		poolSize := u.batch * 4
		if poolSize < u.batch {
			poolSize = u.batch
		}
		u.rxBufferPool = make(chan []byte, poolSize)
		for i := 0; i < poolSize; i++ {
			u.rxBufferPool <- make([]byte, size)
		}
	}
}

func (u *StdConn) borrowRxBuffer(desired int) []byte {
	if desired < MTU {
		desired = MTU
	}
	if u.rxBufferPool == nil {
		return make([]byte, desired)
	}
	buf := <-u.rxBufferPool
	if cap(buf) < desired {
		buf = make([]byte, desired)
	}
	return buf[:desired]
}

func (u *StdConn) recycleBuffer(buf []byte) {
	if buf == nil {
		return
	}
	if u.rxBufferPool == nil {
		return
	}
	buf = buf[:cap(buf)]
	desired := int(u.groBufSize.Load())
	if desired < MTU {
		desired = MTU
	}
	if cap(buf) < desired {
		return
	}
	select {
	case u.rxBufferPool <- buf[:desired]:
	default:
	}
}

func (u *StdConn) recycleBufferSet(bufs [][]byte) {
	for i := range bufs {
		u.recycleBuffer(bufs[i])
	}
}

func isSocketCloseError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, unix.EPIPE) || errors.Is(err, unix.ENOTCONN) || errors.Is(err, unix.EINVAL) || errors.Is(err, unix.EBADF) {
		return true
	}
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		if errno, ok := opErr.Err.(syscall.Errno); ok {
			switch errno {
			case unix.EPIPE, unix.ENOTCONN, unix.EINVAL, unix.EBADF:
				return true
			}
		}
	}
	return false
}

func (u *StdConn) recordGSOFallback(reason string) {
	if u == nil {
		return
	}
	if reason == "" {
		reason = "unknown"
	}
	if u.gsoFallbacks != nil {
		u.gsoFallbacks.Inc(1)
	}
	u.gsoFallbackMu.Lock()
	counter, ok := u.gsoFallbackReasons[reason]
	if !ok {
		counter = &atomic.Int64{}
		u.gsoFallbackReasons[reason] = counter
	}
	counter.Add(1)
	u.gsoFallbackMu.Unlock()
}

func (u *StdConn) recordGSOSingle(count int) {
	if u == nil || count <= 0 {
		return
	}
	if u.gsoSingles != nil {
		u.gsoSingles.Inc(int64(count))
	}
	u.gsoSingleTick.Add(int64(count))
}

func (u *StdConn) snapshotGSOFallbacks() map[string]int64 {
	u.gsoFallbackMu.Lock()
	defer u.gsoFallbackMu.Unlock()
	if len(u.gsoFallbackReasons) == 0 {
		return nil
	}
	out := make(map[string]int64, len(u.gsoFallbackReasons))
	for reason, counter := range u.gsoFallbackReasons {
		if counter == nil {
			continue
		}
		count := counter.Swap(0)
		if count != 0 {
			out[reason] = count
		}
	}
	return out
}

func (u *StdConn) logGSOTick() {
	u.gsoBatchTick.Store(0)
	u.gsoBatchSegmentsTick.Store(0)
	u.gsoSingleTick.Store(0)
	u.groBatchTick.Store(0)
	u.groSegmentsTick.Store(0)
	u.snapshotGSOFallbacks()
}

func (u *StdConn) borrowGSOBuf() []byte {
	size := u.gsoMaxBytes
	if size <= 0 {
		size = MTU
	}
	if v := u.gsoBufferPool.Get(); v != nil {
		buf := v.([]byte)
		if cap(buf) < size {
			u.gsoBufferPool.Put(buf[:0])
			return make([]byte, 0, size)
		}
		return buf[:0]
	}
	return make([]byte, 0, size)
}

func (u *StdConn) borrowIOBuf(size int) []byte {
	if size <= 0 {
		size = MTU
	}
	if v := u.gsoBufferPool.Get(); v != nil {
		buf := v.([]byte)
		if cap(buf) < size {
			u.gsoBufferPool.Put(buf[:0])
			return make([]byte, 0, size)
		}
		return buf[:0]
	}
	return make([]byte, 0, size)
}

func (u *StdConn) releaseGSOBuf(buf []byte) {
	if buf == nil {
		return
	}
	size := u.gsoMaxBytes
	if size <= 0 {
		size = MTU
	}
	buf = buf[:0]
	if cap(buf) > size*4 {
		return
	}
	u.gsoBufferPool.Put(buf)
}

func (s *sendShard) ensureMmsgCapacity(n int) {
	if cap(s.mmsgHeaders) < n {
		s.mmsgHeaders = make([]linuxMmsgHdr, n)
	}
	s.mmsgHeaders = s.mmsgHeaders[:n]
	if cap(s.mmsgIovecs) < n {
		s.mmsgIovecs = make([]unix.Iovec, n)
	}
	s.mmsgIovecs = s.mmsgIovecs[:n]
	if cap(s.mmsgLengths) < n {
		s.mmsgLengths = make([]int, n)
	}
	s.mmsgLengths = s.mmsgLengths[:n]
}

func (s *sendShard) ensurePendingBuf(p *StdConn) {
	if s.pendingBuf == nil {
		s.pendingBuf = p.borrowGSOBuf()
	}
}

func (s *sendShard) startSender() {
	if s.outQueue != nil {
		return
	}
	s.outQueue = make(chan *sendTask, sendShardQueueDepth)
	s.workerDone.Add(1)
	go s.senderLoop()
}

func (s *sendShard) stopSender() {
	s.closeSender()
	s.workerDone.Wait()
}

func (s *sendShard) closeSender() {
	s.mu.Lock()
	queue := s.outQueue
	s.outQueue = nil
	s.mu.Unlock()
	if queue != nil {
		close(queue)
	}
}

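// submitTask hands a task to the shard's sender goroutine, falling back to a
// synchronous send when the queue is full or gone. The recover() shields the
// enqueue against a channel concurrently closed by closeSender.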
func (s *sendShard) submitTask(task *sendTask) error {
	if task == nil {
		return nil
	}
	if len(task.buf) == 0 {
		if task.owned && task.buf != nil && s.parent != nil {
			s.parent.releaseGSOBuf(task.buf)
		}
		return nil
	}

	if s.parent != nil && s.parent.ioClosing.Load() {
		if task.owned && task.buf != nil {
			s.parent.releaseGSOBuf(task.buf)
		}
		return &net.OpError{Op: "sendmsg", Err: net.ErrClosed}
	}

	queue := s.outQueue
	if queue != nil {
		sent := false
		func() {
			defer func() {
				if r := recover(); r != nil {
					sent = false
				}
			}()
			select {
			case queue <- task:
				sent = true
			default:
			}
		}()
		if sent {
			return nil
		}
	}

	return s.processTask(task)
}

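// senderLoop drains the shard queue and groups tasks into batches of up to
// currentMaxBatch(). Holdoffs at or below ioUringHoldoffSpinThreshold are
// spin-waited to dodge timer latency; longer holdoffs arm a time.Timer.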
func (s *sendShard) senderLoop() {
	defer s.workerDone.Done()
	// Capture the queue once: closeSender nils out s.outQueue, and receiving
	// from a nil channel would block this goroutine forever.
	queue := s.outQueue
	initialCap := s.currentMaxBatch()
	if initialCap <= 0 {
		initialCap = ioUringDefaultMaxBatch
	}
	batch := make([]*sendTask, 0, initialCap)
	var holdoffTimer *time.Timer
	var holdoffCh <-chan time.Time

	stopTimer := func() {
		if holdoffTimer == nil {
			return
		}
		if !holdoffTimer.Stop() {
			select {
			case <-holdoffTimer.C:
			default:
			}
		}
		holdoffTimer = nil
		holdoffCh = nil
	}

	resetTimer := func() {
		holdoff := s.currentHoldoff()
		if holdoff <= 0 {
			return
		}
		if holdoffTimer == nil {
			holdoffTimer = time.NewTimer(holdoff)
			holdoffCh = holdoffTimer.C
			return
		}
		if !holdoffTimer.Stop() {
			select {
			case <-holdoffTimer.C:
			default:
			}
		}
		holdoffTimer.Reset(holdoff)
		holdoffCh = holdoffTimer.C
	}

	flush := func() {
		if len(batch) == 0 {
			return
		}
		stopTimer()
		if err := s.processTasksBatch(batch); err != nil && s.parent != nil {
			s.parent.l.WithError(err).Debug("io_uring batch send encountered error")
		}
		for i := range batch {
			batch[i] = nil
		}
		batch = batch[:0]
	}

	for {
		if len(batch) == 0 {
			if s.parent != nil && s.parent.ioClosing.Load() {
				flush()
				stopTimer()
				return
			}
			task, ok := <-queue
			if !ok {
				flush()
				stopTimer()
				return
			}
			if task == nil {
				continue
			}
			batch = append(batch, task)
			maxBatch := s.currentMaxBatch()
			holdoff := s.currentHoldoff()
			if len(batch) >= maxBatch || holdoff <= 0 {
				flush()
				continue
			}
			if holdoff <= ioUringHoldoffSpinThreshold {
				deadline := time.Now().Add(holdoff)
				for {
					if len(batch) >= maxBatch {
						break
					}
					remaining := time.Until(deadline)
					if remaining <= 0 {
						break
					}
					select {
					case next, ok := <-queue:
						if !ok {
							flush()
							return
						}
						if next == nil {
							continue
						}
						if s.parent != nil && s.parent.ioClosing.Load() {
							flush()
							return
						}
						batch = append(batch, next)
					default:
						if remaining > 5*time.Microsecond {
							runtime.Gosched()
						}
					}
				}
				flush()
				continue
			}
			resetTimer()
			continue
		}

		select {
		case task, ok := <-queue:
			if !ok {
				flush()
				stopTimer()
				return
			}
			if task == nil {
				continue
			}
			if s.parent != nil && s.parent.ioClosing.Load() {
				flush()
				stopTimer()
				return
			}
			batch = append(batch, task)
			if len(batch) >= s.currentMaxBatch() {
				flush()
			} else if s.currentHoldoff() > 0 {
				resetTimer()
			}
		case <-holdoffCh:
			stopTimer()
			flush()
		}
	}
}

func (s *sendShard) processTask(task *sendTask) error {
	return s.processTasksBatch([]*sendTask{task})
}

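// processTasksBatch pushes tasks through io_uring when a ring is available,
// otherwise sends them one at a time via the syscall fallback, and finally
// releases any pool-owned buffers.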
func (s *sendShard) processTasksBatch(tasks []*sendTask) error {
	if len(tasks) == 0 {
		return nil
	}
	p := s.parent
	state := p.ioState.Load()
	var firstErr error
	if state != nil {
		if err := s.processTasksBatchIOUring(state, tasks); err != nil {
			firstErr = err
		}
	} else {
		for _, task := range tasks {
			if err := s.processTaskFallback(task); err != nil && firstErr == nil {
				firstErr = err
			}
		}
	}
	for _, task := range tasks {
		if task == nil {
			continue
		}
		if task.owned && task.buf != nil {
			p.releaseGSOBuf(task.buf)
		}
		task.buf = nil
	}
	return firstErr
}

func (s *sendShard) processTasksBatchIOUring(state *ioUringState, tasks []*sendTask) error {
	capEstimate := 0
	maxSeg := 1
	if s.parent != nil && s.parent.ioUringMaxBatch.Load() > 0 {
		maxSeg = int(s.parent.ioUringMaxBatch.Load())
	}
	for _, task := range tasks {
		if task == nil || len(task.buf) == 0 {
			continue
		}
		if task.segSize > 0 && task.segSize < len(task.buf) {
			capEstimate += (len(task.buf) + task.segSize - 1) / task.segSize
		} else {
			capEstimate++
		}
	}
	if capEstimate <= 0 {
		capEstimate = len(tasks)
	}
	if capEstimate > maxSeg {
		capEstimate = maxSeg
	}
	items := make([]*batchSendItem, 0, capEstimate)
	for _, task := range tasks {
		if task == nil || len(task.buf) == 0 {
			continue
		}
		useGSO := s.parent.enableGSO && task.segments > 1
		if useGSO {
			control := make([]byte, unix.CmsgSpace(2))
			hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
			setCmsgLen(hdr, 2)
			hdr.Level = unix.SOL_UDP
			hdr.Type = unix.UDP_SEGMENT
			dataOff := unix.CmsgLen(0)
			binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(task.segSize))
			items = append(items, &batchSendItem{
				task:     task,
				addr:     task.addr,
				payload:  task.buf,
				control:  control,
				msgFlags: 0,
			})
			continue
		}

		segSize := task.segSize
		if segSize <= 0 || segSize >= len(task.buf) {
			items = append(items, &batchSendItem{
				task:    task,
				addr:    task.addr,
				payload: task.buf,
			})
			continue
		}

		for offset := 0; offset < len(task.buf); offset += segSize {
			end := offset + segSize
			if end > len(task.buf) {
				end = len(task.buf)
			}
			segment := task.buf[offset:end]
			items = append(items, &batchSendItem{
				task:    task,
				addr:    task.addr,
				payload: segment,
			})
		}
	}

	if len(items) == 0 {
		return nil
	}

	if err := s.parent.sendMsgIOUringBatch(state, items); err != nil {
		return err
	}

	var firstErr error
	for _, item := range items {
		if item.err != nil && firstErr == nil {
			firstErr = item.err
		}
	}
	if firstErr != nil {
		return firstErr
	}

	for _, task := range tasks {
		if task == nil {
			continue
		}
		if s.parent.enableGSO && task.segments > 1 {
			s.recordGSOMetrics(task)
		} else {
			s.parent.recordGSOSingle(task.segments)
		}
	}

	return nil
}

func (s *sendShard) processTaskFallback(task *sendTask) error {
	if task == nil || len(task.buf) == 0 {
		return nil
	}
	p := s.parent
	useGSO := p.enableGSO && task.segments > 1
	s.mu.Lock()
	defer s.mu.Unlock()
	if useGSO {
		if err := s.sendSegmentedLocked(task.buf, task.addr, task.segSize); err != nil {
			return err
		}
		s.recordGSOMetrics(task)
		return nil
	}
	if err := s.sendSequentialLocked(task.buf, task.addr, task.segSize); err != nil {
		return err
	}
	p.recordGSOSingle(task.segments)
	return nil
}

func (s *sendShard) recordGSOMetrics(task *sendTask) {
	p := s.parent
	if p.gsoBatches != nil {
		p.gsoBatches.Inc(1)
	}
	if p.gsoSegments != nil {
		p.gsoSegments.Inc(int64(task.segments))
	}
	p.gsoBatchTick.Add(1)
	p.gsoBatchSegmentsTick.Add(int64(task.segments))
	if p.l.IsLevelEnabled(logrus.DebugLevel) {
		p.l.WithFields(logrus.Fields{
			"tag":          "gso-debug",
			"stage":        "flush",
			"segments":     task.segments,
			"segment_size": task.segSize,
			"batch_bytes":  len(task.buf),
			"remote_addr":  task.addr.String(),
		}).Debug("gso batch sent")
	}
}

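// write coalesces same-destination, same-size payloads into the pending GSO
// buffer, flushing whenever the destination or segment size changes, limits
// (gsoMaxBytes, gsoMaxSegments) would be exceeded, or the flush timer fires.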
func (s *sendShard) write(b []byte, addr netip.AddrPort) error {
	if len(b) == 0 {
		return nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	p := s.parent

	if !p.enableGSO || !addr.IsValid() {
		p.recordGSOSingle(1)
		return p.directWrite(b, addr)
	}

	s.ensurePendingBuf(p)

	if s.pendingSegments > 0 && s.pendingAddr != addr {
		if err := s.flushPendingLocked(); err != nil {
			return err
		}
		s.ensurePendingBuf(p)
	}

	if len(b) > p.gsoMaxBytes || p.gsoMaxSegments <= 1 {
		if err := s.flushPendingLocked(); err != nil {
			return err
		}
		p.recordGSOSingle(1)
		return p.directWrite(b, addr)
	}

	if s.pendingSegments == 0 {
		s.pendingAddr = addr
		s.pendingSegSize = len(b)
	} else if len(b) != s.pendingSegSize {
		if err := s.flushPendingLocked(); err != nil {
			return err
		}
		s.pendingAddr = addr
		s.pendingSegSize = len(b)
		s.ensurePendingBuf(p)
	}

	if len(s.pendingBuf)+len(b) > p.gsoMaxBytes {
		if err := s.flushPendingLocked(); err != nil {
			return err
		}
		s.pendingAddr = addr
		s.pendingSegSize = len(b)
		s.ensurePendingBuf(p)
	}

	s.pendingBuf = append(s.pendingBuf, b...)
	s.pendingSegments++

	if s.pendingSegments >= p.gsoMaxSegments {
		return s.flushPendingLocked()
	}

	if p.gsoFlushTimeout <= 0 {
		return s.flushPendingLocked()
	}

	s.scheduleFlushLocked()
	return nil
}

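// flushPendingLocked converts the pending buffer into a sendTask and submits
// it. Note it drops and re-acquires s.mu around submitTask, so shard state
// may change across the call.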
func (s *sendShard) flushPendingLocked() error {
	if s.pendingSegments == 0 {
		s.stopFlushTimerLocked()
		return nil
	}

	buf := s.pendingBuf
	task := &sendTask{
		buf:      buf,
		addr:     s.pendingAddr,
		segSize:  s.pendingSegSize,
		segments: s.pendingSegments,
		owned:    true,
	}

	s.pendingBuf = nil
	s.pendingSegments = 0
	s.pendingSegSize = 0
	s.pendingAddr = netip.AddrPort{}

	s.stopFlushTimerLocked()

	s.mu.Unlock()
	err := s.submitTask(task)
	s.mu.Lock()
	return err
}

func (s *sendShard) enqueueImmediate(payload []byte, addr netip.AddrPort) error {
	if len(payload) == 0 {
		return nil
	}
	if !addr.IsValid() {
		return &net.OpError{Op: "sendmsg", Err: unix.EINVAL}
	}
	if s.parent != nil && s.parent.ioClosing.Load() {
		return &net.OpError{Op: "sendmsg", Err: net.ErrClosed}
	}

	buf := s.parent.borrowIOBuf(len(payload))
	buf = append(buf[:0], payload...)

	task := &sendTask{
		buf:      buf,
		addr:     addr,
		segSize:  len(payload),
		segments: 1,
		owned:    true,
	}
	if err := s.submitTask(task); err != nil {
		return err
	}
	return nil
}

func (s *sendShard) sendSegmentedIOUring(state *ioUringState, buf []byte, addr netip.AddrPort, segSize int) error {
	if state == nil || len(buf) == 0 {
		return nil
	}
	if segSize <= 0 {
		segSize = len(buf)
	}
	if len(s.controlBuf) < unix.CmsgSpace(2) {
		s.controlBuf = make([]byte, unix.CmsgSpace(2))
	}
	control := s.controlBuf[:unix.CmsgSpace(2)]
	for i := range control {
		control[i] = 0
	}
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
	setCmsgLen(hdr, 2)
	hdr.Level = unix.SOL_UDP
	hdr.Type = unix.UDP_SEGMENT
	dataOff := unix.CmsgLen(0)
	binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(segSize))

	n, err := s.parent.sendMsgIOUring(state, addr, buf, control, 0)
	if err != nil {
		return err
	}
	if n != len(buf) {
		return &net.OpError{Op: "sendmsg", Err: unix.EIO}
	}
	return nil
}

func (s *sendShard) sendSequentialIOUring(state *ioUringState, buf []byte, addr netip.AddrPort, segSize int) error {
	if state == nil || len(buf) == 0 {
		return nil
	}
	if segSize <= 0 {
		segSize = len(buf)
	}
	if segSize >= len(buf) {
		n, err := s.parent.sendMsgIOUring(state, addr, buf, nil, 0)
		if err != nil {
			return err
		}
		if n != len(buf) {
			return &net.OpError{Op: "sendmsg", Err: unix.EIO}
		}
		return nil
	}

	total := len(buf)
	offset := 0
	for offset < total {
		end := offset + segSize
		if end > total {
			end = total
		}
		segment := buf[offset:end]
		n, err := s.parent.sendMsgIOUring(state, addr, segment, nil, 0)
		if err != nil {
			return err
		}
		if n != len(segment) {
			return &net.OpError{Op: "sendmsg", Err: unix.EIO}
		}
		offset = end
	}
	return nil
}

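// sendSegmentedLocked sends one GSO super-datagram: a UDP_SEGMENT control
// message carries the segment size so the kernel splits buf on the wire.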
func (s *sendShard) sendSegmentedLocked(buf []byte, addr netip.AddrPort, segSize int) error {
	if len(buf) == 0 {
		return nil
	}
	if segSize <= 0 {
		segSize = len(buf)
	}

	if len(s.controlBuf) < unix.CmsgSpace(2) {
		s.controlBuf = make([]byte, unix.CmsgSpace(2))
	}
	control := s.controlBuf[:unix.CmsgSpace(2)]
	for i := range control {
		control[i] = 0
	}

	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
	setCmsgLen(hdr, 2)
	hdr.Level = unix.SOL_UDP
	hdr.Type = unix.UDP_SEGMENT

	dataOff := unix.CmsgLen(0)
	binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(segSize))

	var sa unix.Sockaddr
	if s.parent.isV4 {
		sa4 := &unix.SockaddrInet4{Port: int(addr.Port())}
		sa4.Addr = addr.Addr().As4()
		sa = sa4
	} else {
		sa6 := &unix.SockaddrInet6{Port: int(addr.Port())}
		sa6.Addr = addr.Addr().As16()
		sa = sa6
	}

	for {
		n, err := unix.SendmsgN(s.parent.sysFd, buf, control[:unix.CmsgSpace(2)], sa, 0)
		if err != nil {
			if err == unix.EINTR {
				continue
			}
			return &net.OpError{Op: "sendmsg", Err: err}
		}
		if n != len(buf) {
			return &net.OpError{Op: "sendmsg", Err: unix.EIO}
		}
		return nil
	}
}

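// sendSequentialLocked is the non-GSO fallback. It slices buf into segSize
// chunks, pushes up to maxSendmmsgBatch of them per sendmmsg call, and
// finishes any segments the kernel skipped with individual writes.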
func (s *sendShard) sendSequentialLocked(buf []byte, addr netip.AddrPort, segSize int) error {
	if len(buf) == 0 {
		return nil
	}
	if segSize <= 0 {
		segSize = len(buf)
	}
	if segSize >= len(buf) {
		return s.parent.directWrite(buf, addr)
	}

	var (
		namePtr *byte
		nameLen uint32
	)
	if s.parent.isV4 {
		var sa4 unix.RawSockaddrInet4
		sa4.Family = unix.AF_INET
		sa4.Addr = addr.Addr().As4()
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa4.Port))[:], addr.Port())
		namePtr = (*byte)(unsafe.Pointer(&sa4))
		nameLen = uint32(unsafe.Sizeof(sa4))
	} else {
		var sa6 unix.RawSockaddrInet6
		sa6.Family = unix.AF_INET6
		sa6.Addr = addr.Addr().As16()
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa6.Port))[:], addr.Port())
		namePtr = (*byte)(unsafe.Pointer(&sa6))
		nameLen = uint32(unsafe.Sizeof(sa6))
	}

	total := len(buf)
	if total == 0 {
		return nil
	}
	basePtr := uintptr(unsafe.Pointer(&buf[0]))
	offset := 0

	for offset < total {
		remaining := total - offset
		segments := (remaining + segSize - 1) / segSize
		if segments > maxSendmmsgBatch {
			segments = maxSendmmsgBatch
		}

		s.ensureMmsgCapacity(segments)
		msgs := s.mmsgHeaders[:segments]
		iovecs := s.mmsgIovecs[:segments]
		lens := s.mmsgLengths[:segments]

		batchStart := offset
		segOffset := offset
		actual := 0
		for actual < segments && segOffset < total {
			segLen := segSize
			if segLen > total-segOffset {
				segLen = total - segOffset
			}

			msgs[actual] = linuxMmsgHdr{}
			lens[actual] = segLen
			iovecs[actual].Base = &buf[segOffset]
			setIovecLen(&iovecs[actual], segLen)
			msgs[actual].Hdr.Iov = &iovecs[actual]
			setMsghdrIovlen(&msgs[actual].Hdr, 1)
			msgs[actual].Hdr.Name = namePtr
			msgs[actual].Hdr.Namelen = nameLen
			msgs[actual].Hdr.Control = nil
			msgs[actual].Hdr.Controllen = 0
			msgs[actual].Hdr.Flags = 0
			msgs[actual].Len = 0

			actual++
			segOffset += segLen
		}
		if actual == 0 {
			break
		}
		msgs = msgs[:actual]
		lens = lens[:actual]

	retry:
		sent, err := sendmmsg(s.parent.sysFd, msgs, 0)
		if err != nil {
			if err == unix.EINTR {
				goto retry
			}
			return &net.OpError{Op: "sendmmsg", Err: err}
		}
		if sent == 0 {
			goto retry
		}

		bytesSent := 0
		for i := 0; i < sent; i++ {
			bytesSent += lens[i]
		}
		offset = batchStart + bytesSent

		if sent < len(msgs) {
			for j := sent; j < len(msgs); j++ {
				start := int(uintptr(unsafe.Pointer(iovecs[j].Base)) - basePtr)
				if start < 0 || start >= total {
					continue
				}
				end := start + lens[j]
				if end > total {
					end = total
				}
				if err := s.parent.directWrite(buf[start:end], addr); err != nil {
					return err
				}
				if end > offset {
					offset = end
				}
			}
		}
	}

	return nil
}

func (s *sendShard) scheduleFlushLocked() {
	timeout := s.parent.gsoFlushTimeout
	if timeout <= 0 {
		_ = s.flushPendingLocked()
		return
	}
	if s.flushTimer == nil {
		s.flushTimer = time.AfterFunc(timeout, s.flushTimerHandler)
		return
	}
	if !s.flushTimer.Stop() {
		// allow existing timer to drain
	}
	if !s.flushTimer.Reset(timeout) {
		s.flushTimer = time.AfterFunc(timeout, s.flushTimerHandler)
	}
}

func (s *sendShard) stopFlushTimerLocked() {
	if s.flushTimer != nil {
		s.flushTimer.Stop()
	}
}

func (s *sendShard) flushTimerHandler() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.pendingSegments == 0 {
		return
	}
	if err := s.flushPendingLocked(); err != nil {
		if !isSocketCloseError(err) {
			s.parent.l.WithError(err).Warn("Failed to flush GSO batch")
		}
	}
}

func maybeIPV4(ip net.IP) (net.IP, bool) {
	ip4 := ip.To4()
	if ip4 != nil {
		return ip4, true
	}
	return ip, false
}

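// NewListener opens the UDP socket, clears IPV6_V6ONLY on IPv6 sockets so a
// single socket can serve both address families, optionally enables
// SO_REUSEPORT and SO_NO_CHECK, binds, and wires up metrics, buffer pools,
// io_uring tunables, and the send shards.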
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
	af := unix.AF_INET6
	if ip.Is4() {
		af = unix.AF_INET
	}
	syscall.ForkLock.RLock()
	fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
	if err == nil {
		unix.CloseOnExec(fd)
	}
	syscall.ForkLock.RUnlock()

	if err != nil {
		unix.Close(fd)
		return nil, fmt.Errorf("unable to open socket: %s", err)
	}

	if af == unix.AF_INET6 {
		if err := unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 0); err != nil {
			l.WithError(err).Warn("Failed to clear IPV6_V6ONLY on IPv6 UDP socket")
		} else if v6only, err := unix.GetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_V6ONLY); err == nil {
			l.WithField("v6only", v6only).Debug("Configured IPv6 UDP socket V6ONLY state")
		}
	}

	if multi {
		if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
			return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
		}
	}

	var sa unix.Sockaddr
	if ip.Is4() {
		sa4 := &unix.SockaddrInet4{Port: port}
		sa4.Addr = ip.As4()
		sa = sa4
	} else {
		sa6 := &unix.SockaddrInet6{Port: port}
		sa6.Addr = ip.As16()
		sa = sa6
	}
	if err = unix.Bind(fd, sa); err != nil {
		return nil, fmt.Errorf("unable to bind to socket: %s", err)
	}

	if ip.Is4() && udpChecksumDisabled() {
		if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_NO_CHECK, 1); err != nil {
			l.WithError(err).Warn("Failed to disable IPv4 UDP checksum via SO_NO_CHECK")
		} else {
			l.Debug("Disabled IPv4 UDP checksum using SO_NO_CHECK")
		}
	}

	conn := &StdConn{
		sysFd:              fd,
		isV4:               ip.Is4(),
		l:                  l,
		batch:              batch,
		gsoMaxSegments:     defaultGSOMaxSegments,
		gsoMaxBytes:        defaultGSOMaxBytes,
		gsoFlushTimeout:    defaultGSOFlushTimeout,
		gsoBatches:         metrics.GetOrRegisterCounter("udp.gso.batches", nil),
		gsoSegments:        metrics.GetOrRegisterCounter("udp.gso.segments", nil),
		gsoSingles:         metrics.GetOrRegisterCounter("udp.gso.singles", nil),
		groBatches:         metrics.GetOrRegisterCounter("udp.gro.batches", nil),
		groSegments:        metrics.GetOrRegisterCounter("udp.gro.segments", nil),
		gsoFallbacks:       metrics.GetOrRegisterCounter("udp.gso.fallbacks", nil),
		gsoFallbackReasons: make(map[string]*atomic.Int64),
	}
	conn.ioUringHoldoff.Store(int64(ioUringDefaultHoldoff))
	conn.ioUringMaxBatch.Store(int64(ioUringDefaultMaxBatch))
	conn.setGroBufferSize(defaultGROReadBufferSize)
	conn.initSendShards()
	return conn, nil
}

func (u *StdConn) Rebind() error {
	return nil
}

func (u *StdConn) SetRecvBuffer(n int) error {
	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
}

func (u *StdConn) SetSendBuffer(n int) error {
	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
}

func (u *StdConn) SetSoMark(mark int) error {
	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark)
}

func (u *StdConn) GetRecvBuffer() (int, error) {
	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}

func (u *StdConn) GetSendBuffer() (int, error) {
	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}

func (u *StdConn) GetSoMark() (int, error) {
	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK)
}

func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
	sa, err := unix.Getsockname(u.sysFd)
	if err != nil {
		return netip.AddrPort{}, err
	}

	switch sa := sa.(type) {
	case *unix.SockaddrInet4:
		return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil

	case *unix.SockaddrInet6:
		return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil

	default:
		return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
	}
}

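// ListenOut is the receive loop. It prefers the dedicated io_uring receive
// ring when one is active and otherwise falls back to recvmsg/recvmmsg,
// re-preparing its raw messages whenever the desired GRO buffer or control
// lengths change. In both paths GRO super-packets are split via emitSegments.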
func (u *StdConn) ListenOut(r EncReader) {
	var ip netip.Addr

	// Check if io_uring receive ring is available
	recvRing := u.ioRecvState.Load()
	useIoUringRecv := recvRing != nil && u.ioRecvActive.Load()

	u.l.WithFields(logrus.Fields{
		"batch":         u.batch,
		"io_uring_send": u.ioState.Load() != nil,
		"io_uring_recv": useIoUringRecv,
	}).Info("ListenOut starting")

	if useIoUringRecv {
		// Use dedicated io_uring receive ring
		u.l.Info("ListenOut: using io_uring receive path")

		// Pre-fill the receive queue now that we're ready to receive
		if err := recvRing.fillRecvQueue(); err != nil {
			u.l.WithError(err).Error("Failed to fill receive queue")
			return
		}

		for {
			// Receive packets from io_uring (wait=true blocks until at least one packet arrives)
			packets, err := recvRing.receivePackets(true)
			if err != nil {
				u.l.WithError(err).Error("io_uring receive failed")
				return
			}

			if len(packets) > 0 && u.l.IsLevelEnabled(logrus.DebugLevel) {
				totalBytes := 0
				groPackets := 0
				groSegments := 0
				for i := range packets {
					totalBytes += packets[i].N
					if packets[i].Controllen > 0 {
						if _, segCount := u.parseGROSegmentFromControl(packets[i].Control, packets[i].Controllen); segCount > 1 {
							groPackets++
							groSegments += segCount
						}
					}
				}
				fields := logrus.Fields{
					"entry_count":   len(packets),
					"payload_bytes": totalBytes,
				}
				if groPackets > 0 {
					fields["gro_packets"] = groPackets
					fields["gro_segments"] = groSegments
				}
				u.l.WithFields(fields).Debug("io_uring recv batch")
			}

			for _, pkt := range packets {
				// Extract address from RawSockaddrInet6
				if pkt.From.Family != unix.AF_INET6 {
					u.l.WithField("family", pkt.From.Family).Warn("Received packet with unexpected address family")
					continue
				}

				ip, _ = netip.AddrFromSlice(pkt.From.Addr[:])
				addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&pkt.From.Port))[:]))
				payload := pkt.Data[:pkt.N]
				release := pkt.RecycleFunc
				released := false
				releaseOnce := func() {
					if !released {
						released = true
						release()
					}
				}

				// Check for GRO segments
				handled := false
				if pkt.Controllen > 0 && len(pkt.Control) > 0 {
					if segSize, segCount := u.parseGROSegmentFromControl(pkt.Control, pkt.Controllen); segSize > 0 && segSize < pkt.N {
						if segCount > 1 && u.l.IsLevelEnabled(logrus.DebugLevel) {
							u.l.WithFields(logrus.Fields{
								"segments":     segCount,
								"segment_size": segSize,
								"batch_bytes":  pkt.N,
								"remote_addr":  addr.String(),
							}).Debug("gro batch received")
						}
						if u.emitSegments(r, addr, payload, segSize, segCount, releaseOnce) {
							handled = true
						} else if segCount > 1 {
							u.l.WithFields(logrus.Fields{
								"tag":         "gro-debug",
								"stage":       "io_uring_recv",
								"reason":      "emit_failed",
								"payload_len": pkt.N,
								"seg_size":    segSize,
								"seg_count":   segCount,
							}).Debug("gro-debug fallback to single packet")
						}
					}
				}

				if !handled {
					r(addr, payload, releaseOnce)
				}
			}
		}
	}

	// Fallback path: use standard recvmsg
	u.l.Info("ListenOut: using standard recvmsg path")
	msgs, buffers, names, controls := u.PrepareRawMessages(u.batch)
	read := u.ReadMulti
	if u.batch == 1 {
		read = u.ReadSingle
	}

	u.l.WithFields(logrus.Fields{
		"using_ReadSingle": u.batch == 1,
		"using_ReadMulti":  u.batch != 1,
	}).Info("ListenOut read function selected")

	for {
		desiredGroSize := int(u.groBufSize.Load())
		if desiredGroSize < MTU {
			desiredGroSize = MTU
		}
		if len(buffers) == 0 || cap(buffers[0]) < desiredGroSize {
			u.recycleBufferSet(buffers)
			msgs, buffers, names, controls = u.PrepareRawMessages(u.batch)
		}
		desiredControl := int(u.controlLen.Load())
		hasControl := len(controls) > 0
		if (desiredControl > 0) != hasControl || (desiredControl > 0 && hasControl && len(controls[0]) != desiredControl) {
			u.recycleBufferSet(buffers)
			msgs, buffers, names, controls = u.PrepareRawMessages(u.batch)
			hasControl = len(controls) > 0
		}

		if hasControl {
			for i := range msgs {
				if len(controls) <= i || len(controls[i]) == 0 {
					continue
				}
				msgs[i].Hdr.Controllen = controllen(len(controls[i]))
			}
		}

		u.l.Debug("ListenOut: about to call read(msgs)")
		n, err := read(msgs)
		if err != nil {
			u.l.WithError(err).Error("ListenOut: read(msgs) failed, exiting read loop")
			u.recycleBufferSet(buffers)
			return
		}
		u.l.WithField("packets_read", n).Debug("ListenOut: read(msgs) returned")

		for i := 0; i < n; i++ {
			payloadLen := int(msgs[i].Len)
			if payloadLen == 0 {
				continue
			}

			// It's ok to skip the ok check here; the slicing is the only error that can occur and it will panic
			if u.isV4 {
				ip, _ = netip.AddrFromSlice(names[i][4:8])
			} else {
				ip, _ = netip.AddrFromSlice(names[i][8:24])
			}
			addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
			buf := buffers[i]
			payload := buf[:payloadLen]
			released := false
			release := func() {
				if !released {
					released = true
					u.recycleBuffer(buf)
				}
			}
			handled := false

			if len(controls) > i && len(controls[i]) > 0 {
				if segSize, segCount := u.parseGROSegment(&msgs[i], controls[i]); segSize > 0 && segSize < payloadLen {
					if segCount > 1 && u.l.IsLevelEnabled(logrus.DebugLevel) {
						u.l.WithFields(logrus.Fields{
							"segments":     segCount,
							"segment_size": segSize,
							"batch_bytes":  payloadLen,
							"remote_addr":  addr.String(),
						}).Debug("gro batch received")
					}
					if u.emitSegments(r, addr, payload, segSize, segCount, release) {
						handled = true
					} else if segCount > 1 {
						u.l.WithFields(logrus.Fields{
							"tag":         "gro-debug",
							"stage":       "listen_out",
							"reason":      "emit_failed",
							"payload_len": payloadLen,
							"seg_size":    segSize,
							"seg_count":   segCount,
						}).Debug("gro-debug fallback to single packet")
					}
				}
			}

			if !handled {
				r(addr, payload, release)
			}

			buffers[i] = u.borrowRxBuffer(desiredGroSize)
			setIovecBase(&msgs[i], buffers[i])
		}
	}
}

func isEAgain(err error) bool {
	if err == nil {
		return false
	}
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		if errno, ok := opErr.Err.(syscall.Errno); ok {
			return errno == unix.EAGAIN || errno == unix.EWOULDBLOCK
		}
	}
	if errno, ok := err.(syscall.Errno); ok {
		return errno == unix.EAGAIN || errno == unix.EWOULDBLOCK
	}
	return false
}

func (u *StdConn) readSingleSyscall(msgs []rawMessage) (int, error) {
	if len(msgs) == 0 {
		return 0, nil
	}
	for {
		n, _, errno := unix.Syscall6(
			unix.SYS_RECVMSG,
			uintptr(u.sysFd),
			uintptr(unsafe.Pointer(&msgs[0].Hdr)),
			0,
			0,
			0,
			0,
		)
		if errno != 0 {
			err := syscall.Errno(errno)
			if err == unix.EINTR {
				continue
			}
			return 0, &net.OpError{Op: "recvmsg", Err: err}
		}
		msgs[0].Len = uint32(n)
		return 1, nil
	}
}

func (u *StdConn) readMultiSyscall(msgs []rawMessage) (int, error) {
	if len(msgs) == 0 {
		return 0, nil
	}
	for {
		n, _, errno := unix.Syscall6(
			unix.SYS_RECVMMSG,
			uintptr(u.sysFd),
			uintptr(unsafe.Pointer(&msgs[0])),
			uintptr(len(msgs)),
			unix.MSG_WAITFORONE,
			0,
			0,
		)
		if errno != 0 {
			err := syscall.Errno(errno)
			if err == unix.EINTR {
				continue
			}
			return 0, &net.OpError{Op: "recvmmsg", Err: err}
		}
		return int(n), nil
	}
}

func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
	if len(msgs) == 0 {
		return 0, nil
	}

	u.l.Debug("ReadSingle called")

	state := u.ioState.Load()
	if state == nil {
		return u.readSingleSyscall(msgs)
	}

	u.l.Debug("ReadSingle: converting rawMessage to unix.Msghdr")
	hdr, iov, err := rawMessageToUnixMsghdr(&msgs[0])
	if err != nil {
		u.l.WithError(err).Error("ReadSingle: rawMessageToUnixMsghdr failed")
		return 0, &net.OpError{Op: "recvmsg", Err: err}
	}

	u.l.WithFields(logrus.Fields{
		"bufLen":  iov.Len,
		"nameLen": hdr.Namelen,
		"ctrlLen": hdr.Controllen,
	}).Debug("ReadSingle: calling state.Recvmsg")

	n, _, recvErr := state.Recvmsg(u.sysFd, &hdr, 0)
	if recvErr != nil {
		u.l.WithError(recvErr).Error("ReadSingle: state.Recvmsg failed")
		return 0, recvErr
	}

	u.l.WithFields(logrus.Fields{
		"bytesRead": n,
	}).Debug("ReadSingle: successfully received")

	updateRawMessageFromUnixMsghdr(&msgs[0], &hdr, n)
	runtime.KeepAlive(iov)
	runtime.KeepAlive(hdr)
	return 1, nil
}

func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
	if len(msgs) == 0 {
		return 0, nil
	}

	u.l.WithField("batch_size", len(msgs)).Debug("ReadMulti called")

	state := u.ioState.Load()
	if state == nil {
		return u.readMultiSyscall(msgs)
	}

	count := 0
	for i := range msgs {
		hdr, iov, err := rawMessageToUnixMsghdr(&msgs[i])
		if err != nil {
			u.l.WithError(err).WithField("index", i).Error("ReadMulti: rawMessageToUnixMsghdr failed")
			if count > 0 {
				return count, nil
			}
			return 0, &net.OpError{Op: "recvmsg", Err: err}
		}

		flags := uint32(0)
		if i > 0 {
			flags = unix.MSG_DONTWAIT
		}

		u.l.WithFields(logrus.Fields{
			"index":  i,
			"flags":  flags,
			"bufLen": iov.Len,
		}).Debug("ReadMulti: calling state.Recvmsg")

		n, _, recvErr := state.Recvmsg(u.sysFd, &hdr, flags)
		if recvErr != nil {
			u.l.WithError(recvErr).WithFields(logrus.Fields{
				"index": i,
				"count": count,
			}).Debug("ReadMulti: state.Recvmsg error")
			if isEAgain(recvErr) && count > 0 {
				u.l.WithField("count", count).Debug("ReadMulti: EAGAIN with existing packets, returning")
				return count, nil
			}
			if count > 0 {
				return count, recvErr
			}
			return 0, recvErr
		}

		u.l.WithFields(logrus.Fields{
			"index":     i,
			"bytesRead": n,
		}).Debug("ReadMulti: packet received")

		updateRawMessageFromUnixMsghdr(&msgs[i], &hdr, n)
		runtime.KeepAlive(iov)
		runtime.KeepAlive(hdr)
		count++
	}

	u.l.WithField("total_count", count).Debug("ReadMulti: completed")
	return count, nil
}

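// WriteTo routes a single payload: through GSO coalescing when enabled, via a
// send shard when io_uring is active, otherwise as a direct synchronous write.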
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
	if len(b) == 0 {
		return nil
	}
	if u.ioClosing.Load() {
		return &net.OpError{Op: "sendmsg", Err: net.ErrClosed}
	}
	if u.enableGSO {
		return u.writeToGSO(b, ip)
	}
	if u.ioState.Load() != nil {
		if shard := u.selectSendShard(ip); shard != nil {
			if err := shard.enqueueImmediate(b, ip); err != nil {
				return err
			}
			return nil
		}
	}
	u.recordGSOSingle(1)
	return u.directWrite(b, ip)
}

func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
	var rsa unix.RawSockaddrInet6
	rsa.Family = unix.AF_INET6
	rsa.Addr = ip.Addr().As16()
	binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())

	for {
		_, _, err := unix.Syscall6(
			unix.SYS_SENDTO,
			uintptr(u.sysFd),
			uintptr(unsafe.Pointer(&b[0])),
			uintptr(len(b)),
			uintptr(0),
			uintptr(unsafe.Pointer(&rsa)),
			uintptr(unix.SizeofSockaddrInet6),
		)

		if err != 0 {
			// Retry interrupted syscalls; surface everything else to the caller.
			if err == unix.EINTR {
				continue
			}
			return &net.OpError{Op: "sendto", Err: err}
		}

		return nil
	}
}

func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
	if !ip.Addr().Is4() {
		return ErrInvalidIPv6RemoteForSocket
	}

	var rsa unix.RawSockaddrInet4
	rsa.Family = unix.AF_INET
	rsa.Addr = ip.Addr().As4()
	binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())

	for {
		_, _, err := unix.Syscall6(
			unix.SYS_SENDTO,
			uintptr(u.sysFd),
			uintptr(unsafe.Pointer(&b[0])),
			uintptr(len(b)),
			uintptr(0),
			uintptr(unsafe.Pointer(&rsa)),
			uintptr(unix.SizeofSockaddrInet4),
		)

		if err != 0 {
			// Retry interrupted syscalls; surface everything else to the caller.
			if err == unix.EINTR {
				continue
			}
			return &net.OpError{Op: "sendto", Err: err}
		}

		return nil
	}
}

func (u *StdConn) writeToGSO(b []byte, addr netip.AddrPort) error {
	if len(b) == 0 {
		return nil
	}
	shard := u.selectSendShard(addr)
	if shard == nil {
		u.recordGSOSingle(1)
		return u.directWrite(b, addr)
	}
	return shard.write(b, addr)
}

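// sendMsgIOUring submits one sendmsg through the ring. A zero-byte completion
// without an error is treated as a short write and retried once through
// synchronous sendmsg before the result is surfaced.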
func (u *StdConn) sendMsgIOUring(state *ioUringState, addr netip.AddrPort, payload []byte, control []byte, msgFlags uint32) (int, error) {
	if state == nil {
		return 0, &net.OpError{Op: "sendmsg", Err: syscall.EINVAL}
	}
	if len(payload) == 0 {
		return 0, nil
	}
	if !addr.IsValid() {
		return 0, &net.OpError{Op: "sendmsg", Err: unix.EINVAL}
	}
	if !u.ioAttempted.Load() {
		u.ioAttempted.Store(true)
		u.l.WithFields(logrus.Fields{
			"addr": addr.String(),
			"len":  len(payload),
			"ctrl": control != nil,
		}).Debug("io_uring send attempt")
	}
	u.l.WithFields(logrus.Fields{
		"addr": addr.String(),
		"len":  len(payload),
		"ctrl": control != nil,
	}).Debug("io_uring sendMsgIOUring invoked")

	var iov unix.Iovec
	iov.Base = &payload[0]
	setIovecLen(&iov, len(payload))

	var msg unix.Msghdr
	msg.Iov = &iov
	setMsghdrIovlen(&msg, 1)

	if len(control) > 0 {
		msg.Control = &control[0]
		msg.Controllen = controllen(len(control))
	}

	u.l.WithFields(logrus.Fields{
		"addr":           addr.String(),
		"payload_len":    len(payload),
		"ctrl_len":       len(control),
		"msg_iovlen":     msg.Iovlen,
		"msg_controllen": msg.Controllen,
	}).Debug("io_uring prepared msghdr")

	var (
		n   int
		err error
	)

	if u.isV4 {
		if !addr.Addr().Is4() {
			return 0, ErrInvalidIPv6RemoteForSocket
		}
		var sa4 unix.RawSockaddrInet4
		sa4.Family = unix.AF_INET
		sa4.Addr = addr.Addr().As4()
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa4.Port))[:], addr.Port())
		msg.Name = (*byte)(unsafe.Pointer(&sa4))
		msg.Namelen = uint32(unsafe.Sizeof(sa4))
		u.l.WithFields(logrus.Fields{
			"addr":        addr.String(),
			"sa_family":   sa4.Family,
			"sa_port":     sa4.Port,
			"msg_namelen": msg.Namelen,
		}).Debug("io_uring sendmsg sockaddr v4")
		n, err = state.Sendmsg(u.sysFd, &msg, msgFlags, uint32(len(payload)))
		runtime.KeepAlive(sa4)
	} else {
		// For IPv6 sockets, always use RawSockaddrInet6, even for IPv4 addresses
		// (convert IPv4 to IPv4-mapped IPv6 format)
		var sa6 unix.RawSockaddrInet6
		u.populateSockaddrInet6(&sa6, addr.Addr())
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa6.Port))[:], addr.Port())
		msg.Name = (*byte)(unsafe.Pointer(&sa6))
		msg.Namelen = uint32(unsafe.Sizeof(sa6))
		u.l.WithFields(logrus.Fields{
			"addr":        addr.String(),
			"sa_family":   sa6.Family,
			"sa_port":     sa6.Port,
			"scope_id":    sa6.Scope_id,
			"msg_namelen": msg.Namelen,
			"is_v4":       addr.Addr().Is4(),
		}).Debug("io_uring sendmsg sockaddr v6")
		n, err = state.Sendmsg(u.sysFd, &msg, msgFlags, uint32(len(payload)))
		runtime.KeepAlive(sa6)
	}

	if err == nil && n == len(payload) {
		u.noteIoUringSuccess()
	}
	runtime.KeepAlive(payload)
	runtime.KeepAlive(control)
	u.logIoUringResult(addr, len(payload), n, err)
	if err == nil && n == 0 && len(payload) > 0 {
		syncWritten, syncErr := u.sendMsgSync(addr, payload, control, int(msgFlags))
		if syncErr == nil && syncWritten == len(payload) {
			u.l.WithFields(logrus.Fields{
				"addr":         addr.String(),
				"expected":     len(payload),
				"sync_written": syncWritten,
			}).Warn("io_uring returned short write; used synchronous sendmsg fallback")
			u.noteIoUringSuccess()
			u.logIoUringResult(addr, len(payload), syncWritten, syncErr)
			return syncWritten, nil
		}
		u.l.WithFields(logrus.Fields{
			"addr":         addr.String(),
			"expected":     len(payload),
			"sync_written": syncWritten,
			"sync_err":     syncErr,
		}).Warn("sync sendmsg result after io_uring short write")
	}
	return n, err
}

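// sendMsgIOUringBatch prepares one msghdr/iovec/sockaddr per item, submits
// them as a single ring batch, and maps each completion back onto its item,
// recording per-item errors and short writes.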
func (u *StdConn) sendMsgIOUringBatch(state *ioUringState, items []*batchSendItem) error {
	if u.ioClosing.Load() {
		for _, item := range items {
			if item != nil {
				item.err = &net.OpError{Op: "sendmsg", Err: net.ErrClosed}
			}
		}
		return &net.OpError{Op: "sendmsg", Err: net.ErrClosed}
	}
	if state == nil {
		return &net.OpError{Op: "sendmsg", Err: syscall.EINVAL}
	}
	if len(items) == 0 {
		return nil
	}

	results := make([]ioUringBatchResult, len(items))
	payloads := make([][]byte, len(items))
	controls := make([][]byte, len(items))
	entries := make([]ioUringBatchEntry, len(items))
	msgs := make([]unix.Msghdr, len(items))
	iovecs := make([]unix.Iovec, len(items))
	var sa4 []unix.RawSockaddrInet4
	var sa6 []unix.RawSockaddrInet6
	if u.isV4 {
		sa4 = make([]unix.RawSockaddrInet4, len(items))
	} else {
		sa6 = make([]unix.RawSockaddrInet6, len(items))
	}

	entryIdx := 0
	totalPayload := 0
	skipped := 0
	for i, item := range items {
		// Guard the nil case separately; dereferencing a nil item would panic.
		if item == nil {
			skipped++
			continue
		}
		if len(item.payload) == 0 {
			item.resultBytes = 0
			item.err = nil
			skipped++
			continue
		}

		addr := item.addr
		if !addr.IsValid() {
			item.err = &net.OpError{Op: "sendmsg", Err: unix.EINVAL}
			skipped++
			continue
		}
		if u.isV4 && !addr.Addr().Is4() {
			item.err = ErrInvalidIPv6RemoteForSocket
			skipped++
			continue
		}

		payload := item.payload
		payloads[i] = payload
		totalPayload += len(payload)

		iov := &iovecs[entryIdx]
		iov.Base = &payload[0]
		setIovecLen(iov, len(payload))

		msg := &msgs[entryIdx]
		msg.Iov = iov
		setMsghdrIovlen(msg, 1)

		if len(item.control) > 0 {
			controls[i] = item.control
			msg.Control = &item.control[0]
			msg.Controllen = controllen(len(item.control))
		}

		if u.isV4 {
			sa := &sa4[entryIdx]
			sa.Family = unix.AF_INET
			sa.Addr = addr.Addr().As4()
			binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
			msg.Name = (*byte)(unsafe.Pointer(sa))
			msg.Namelen = uint32(unsafe.Sizeof(*sa))
		} else {
			sa := &sa6[entryIdx]
			sa.Family = unix.AF_INET6
			u.populateSockaddrInet6(sa, addr.Addr())
			binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
			msg.Name = (*byte)(unsafe.Pointer(sa))
			msg.Namelen = uint32(unsafe.Sizeof(*sa))
		}

		entries[entryIdx] = ioUringBatchEntry{
			fd:         u.sysFd,
			msg:        msg,
			msgFlags:   item.msgFlags,
			payloadLen: uint32(len(payload)),
			result:     &results[i],
		}
		entryIdx++
	}

	if entryIdx == 0 {
		for _, payload := range payloads {
			runtime.KeepAlive(payload)
		}
		for _, control := range controls {
			runtime.KeepAlive(control)
		}
		var firstErr error
		for _, item := range items {
			if item != nil && item.err != nil {
				firstErr = item.err
				break
			}
		}
		return firstErr
	}

	if err := ioUringSendmsgBatch(state, entries[:entryIdx]); err != nil {
		for _, payload := range payloads {
			runtime.KeepAlive(payload)
		}
		for _, control := range controls {
			runtime.KeepAlive(control)
		}
		if len(sa4) > 0 {
			runtime.KeepAlive(sa4[:entryIdx])
		}
		if len(sa6) > 0 {
			runtime.KeepAlive(sa6[:entryIdx])
		}
		return err
	}

	if u.l.IsLevelEnabled(logrus.DebugLevel) {
		u.l.WithFields(logrus.Fields{
			"entry_count":   entryIdx,
			"skipped_items": skipped,
			"payload_bytes": totalPayload,
		}).Debug("io_uring batch submitted")
	}

	var firstErr error
	for i, item := range items {
		if item == nil || len(item.payload) == 0 {
			continue
		}
		if item.err != nil {
			if firstErr == nil {
				firstErr = item.err
			}
			continue
		}

		res := results[i]
		if res.err != nil {
			item.err = res.err
		} else if res.res < 0 {
			item.err = syscall.Errno(-res.res)
		} else if int(res.res) != len(item.payload) {
			item.err = fmt.Errorf("io_uring short write: wrote %d expected %d", res.res, len(item.payload))
		} else {
			item.err = nil
			item.resultBytes = int(res.res)
		}

		u.logIoUringResult(item.addr, len(item.payload), int(res.res), item.err)
		if item.err != nil && firstErr == nil {
			firstErr = item.err
		}
	}

	for _, payload := range payloads {
		runtime.KeepAlive(payload)
	}
	for _, control := range controls {
		runtime.KeepAlive(control)
	}
	if len(sa4) > 0 {
		runtime.KeepAlive(sa4[:entryIdx])
	}
	if len(sa6) > 0 {
		runtime.KeepAlive(sa6[:entryIdx])
	}

	if firstErr == nil {
		u.noteIoUringSuccess()
	}

	return firstErr
}

func (u *StdConn) sendMsgSync(addr netip.AddrPort, payload []byte, control []byte, msgFlags int) (int, error) {
|
|
if len(payload) == 0 {
|
|
return 0, nil
|
|
}
|
|
if u.isV4 {
|
|
if !addr.Addr().Is4() {
|
|
return 0, ErrInvalidIPv6RemoteForSocket
|
|
}
|
|
sa := &unix.SockaddrInet4{Port: int(addr.Port())}
|
|
sa.Addr = addr.Addr().As4()
|
|
return unix.SendmsgN(u.sysFd, payload, control, sa, msgFlags)
|
|
}
|
|
sa := &unix.SockaddrInet6{Port: int(addr.Port())}
|
|
if addr.Addr().Is4() {
|
|
sa.Addr = toIPv4Mapped(addr.Addr().As4())
|
|
} else {
|
|
sa.Addr = addr.Addr().As16()
|
|
}
|
|
if zone := addr.Addr().Zone(); zone != "" {
|
|
if iface, err := net.InterfaceByName(zone); err == nil {
|
|
sa.ZoneId = uint32(iface.Index)
|
|
} else {
|
|
u.l.WithFields(logrus.Fields{
|
|
"addr": addr.Addr().String(),
|
|
"zone": zone,
|
|
}).WithError(err).Debug("io_uring failed to resolve IPv6 zone")
|
|
}
|
|
}
|
|
return unix.SendmsgN(u.sysFd, payload, control, sa, msgFlags)
|
|
}
|
|
|
|
func (u *StdConn) directWrite(b []byte, addr netip.AddrPort) error {
|
|
if len(b) == 0 {
|
|
return nil
|
|
}
|
|
if !addr.IsValid() {
|
|
return &net.OpError{Op: "sendmsg", Err: unix.EINVAL}
|
|
}
|
|
state := u.ioState.Load()
|
|
u.l.WithFields(logrus.Fields{
|
|
"addr": addr.String(),
|
|
"len": len(b),
|
|
"state_nil": state == nil,
|
|
"socket_v4": u.isV4,
|
|
"remote_is_v4": addr.Addr().Is4(),
|
|
"remote_is_v6": addr.Addr().Is6(),
|
|
}).Debug("io_uring directWrite invoked")
|
|
if state == nil {
|
|
written, err := u.sendMsgSync(addr, b, nil, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if written != len(b) {
|
|
return fmt.Errorf("sendmsg short write: wrote %d expected %d", written, len(b))
|
|
}
|
|
return nil
|
|
}
|
|
n, err := u.sendMsgIOUring(state, addr, b, nil, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n != len(b) {
|
|
return fmt.Errorf("io_uring short write: wrote %d expected %d", n, len(b))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (u *StdConn) noteIoUringSuccess() {
|
|
if u == nil {
|
|
return
|
|
}
|
|
if u.ioActive.Load() {
|
|
return
|
|
}
|
|
if u.ioActive.CompareAndSwap(false, true) {
|
|
u.l.Debug("io_uring send path active")
|
|
}
|
|
}
|
|
|
|
func (u *StdConn) logIoUringResult(addr netip.AddrPort, expected, written int, err error) {
	if u == nil {
		return
	}
	if !u.l.IsLevelEnabled(logrus.DebugLevel) {
		// Skip building the field map on the hot path when debug is off.
		return
	}
	u.l.WithFields(logrus.Fields{
		"addr":         addr.String(),
		"expected":     expected,
		"written":      written,
		"err":          err,
		"socket_v4":    u.isV4,
		"remote_is_v4": addr.Addr().Is4(),
		"remote_is_v6": addr.Addr().Is6(),
	}).Debug("io_uring send result")
}
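
// emitSegments splits a GRO-coalesced payload into segSize-sized chunks and
// hands each one to the reader callback r. The caller's release function is
// invoked at most once, after every per-segment release has fired; invalid or
// single-segment payloads are released immediately and return false.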
func (u *StdConn) emitSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize, segCount int, release func()) bool {
	var releaseFlag atomic.Bool
	releaseOnce := func() {
		if release != nil && releaseFlag.CompareAndSwap(false, true) {
			release()
		}
	}

	if segSize <= 0 || segSize >= len(payload) {
		u.l.WithFields(logrus.Fields{
			"tag":         "gro-debug",
			"stage":       "emit",
			"reason":      "invalid_seg_size",
			"payload_len": len(payload),
			"seg_size":    segSize,
			"seg_count":   segCount,
		}).Debug("gro-debug skip emit")
		releaseOnce()
		return false
	}

	totalLen := len(payload)
	if segCount <= 0 {
		segCount = (totalLen + segSize - 1) / segSize
	}
	if segCount <= 1 {
		u.l.WithFields(logrus.Fields{
			"tag":         "gro-debug",
			"stage":       "emit",
			"reason":      "single_segment",
			"payload_len": totalLen,
			"seg_size":    segSize,
			"seg_count":   segCount,
		}).Debug("gro-debug skip emit")
		releaseOnce()
		return false
	}

	starts := make([]int, 0, segCount)
	lens := make([]int, 0, segCount)
	debugEnabled := u.l.IsLevelEnabled(logrus.DebugLevel)
	var firstHeader header.H
	var firstParsed bool
	var firstCounter uint64
	var firstRemote uint32
	actualSegments := 0
	start := 0

	for start < totalLen && actualSegments < segCount {
		end := start + segSize
		if end > totalLen {
			end = totalLen
		}
		segLen := end - start
		segment := payload[start:end]

		if debugEnabled && !firstParsed {
			if err := firstHeader.Parse(segment); err == nil {
				firstParsed = true
				firstCounter = firstHeader.MessageCounter
				firstRemote = firstHeader.RemoteIndex
			} else {
				u.l.WithFields(logrus.Fields{
					"tag":         "gro-debug",
					"stage":       "emit",
					"event":       "parse_fail",
					"seg_index":   actualSegments,
					"seg_size":    segSize,
					"seg_count":   segCount,
					"payload_len": totalLen,
					"err":         err,
				}).Debug("gro-debug segment parse failed")
			}
		}

		starts = append(starts, start)
		lens = append(lens, segLen)
		start = end
		actualSegments++

		if debugEnabled && actualSegments == segCount && segLen < segSize {
			var tail header.H
			if err := tail.Parse(segment); err == nil {
				u.l.WithFields(logrus.Fields{
					"tag":             "gro-debug",
					"stage":           "emit",
					"event":           "tail_segment",
					"segment_len":     segLen,
					"remote_index":    tail.RemoteIndex,
					"message_counter": tail.MessageCounter,
				}).Debug("gro-debug tail segment metadata")
			}
		}
	}

	if actualSegments == 0 {
		releaseOnce()
		return false
	}

	remaining := int32(actualSegments)
	for i := range starts {
		segment := payload[starts[i] : starts[i]+lens[i]]
		segmentRelease := func() {
			if atomic.AddInt32(&remaining, -1) == 0 {
				releaseOnce()
			}
		}
		r(addr, segment, segmentRelease)
	}

	if u.groBatches != nil {
		u.groBatches.Inc(1)
	}
	if u.groSegments != nil {
		u.groSegments.Inc(int64(actualSegments))
	}
	u.groBatchTick.Add(1)
	u.groSegmentsTick.Add(int64(actualSegments))

	if debugEnabled {
		lastLen := segSize
		if tail := totalLen % segSize; tail != 0 {
			lastLen = tail
		}
		u.l.WithFields(logrus.Fields{
			"tag":           "gro-debug",
			"stage":         "emit",
			"event":         "success",
			"payload_len":   totalLen,
			"seg_size":      segSize,
			"seg_count":     segCount,
			"actual_segs":   actualSegments,
			"last_seg_len":  lastLen,
			"addr":          addr.String(),
			"first_remote":  firstRemote,
			"first_counter": firstCounter,
		}).Debug("gro-debug emit")
	}

	return true
}
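
// parseGROSegment extracts the GRO segment size and count from the control
// data attached to a received message, returning (0, 0) when absent.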
func (u *StdConn) parseGROSegment(msg *rawMessage, control []byte) (int, int) {
	ctrlLen := int(msg.Hdr.Controllen)
	if ctrlLen <= 0 {
		return 0, 0
	}
	if ctrlLen > len(control) {
		ctrlLen = len(control)
	}
	return u.parseGROSegmentFromControl(control, ctrlLen)
}
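
// parseGROSegmentFromControl scans raw cmsg bytes for a UDP_GRO control
// message and returns the kernel-reported segment size and, when present,
// the segment count.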
func (u *StdConn) parseGROSegmentFromControl(control []byte, ctrlLen int) (int, int) {
	if ctrlLen <= 0 {
		return 0, 0
	}
	if ctrlLen > len(control) {
		ctrlLen = len(control)
	}

	cmsgs, err := unix.ParseSocketControlMessage(control[:ctrlLen])
	if err != nil {
		u.l.WithError(err).Debug("failed to parse UDP GRO control message")
		return 0, 0
	}

	for _, c := range cmsgs {
		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
			segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
			segCount := 0
			if len(c.Data) >= 4 {
				segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
			}
			u.l.WithFields(logrus.Fields{
				"tag":       "gro-debug",
				"stage":     "parse",
				"seg_size":  segSize,
				"seg_count": segCount,
			}).Debug("gro-debug control parsed")
			return segSize, segCount
		}
	}

	return 0, 0
}
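
// configureIOUring enables the io_uring send and receive rings or, when
// disabled, re-applies the tunables that remain meaningful. It honors
// listen.io_uring_entries, listen.io_uring_batch_holdoff,
// listen.io_uring_max_batch and listen.send_shards.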
func (u *StdConn) configureIOUring(enable bool, c *config.C) {
	if enable {
		if u.ioState.Load() != nil {
			return
		}

		// Serialize io_uring initialization globally to avoid kernel resource races
		ioUringInitMu.Lock()
		defer ioUringInitMu.Unlock()

		var configured uint32
		requestedBatch := ioUringDefaultMaxBatch
		if c != nil {
			entries := c.GetInt("listen.io_uring_entries", 0)
			if entries < 0 {
				entries = 0
			}
			configured = uint32(entries)
			holdoff := c.GetDuration("listen.io_uring_batch_holdoff", -1)
			if holdoff < 0 {
				holdoffVal := c.GetInt("listen.io_uring_batch_holdoff", int(ioUringDefaultHoldoff/time.Microsecond))
				holdoff = time.Duration(holdoffVal) * time.Microsecond
			}
			if holdoff < ioUringMinHoldoff {
				holdoff = ioUringMinHoldoff
			}
			if holdoff > ioUringMaxHoldoff {
				holdoff = ioUringMaxHoldoff
			}
			u.ioUringHoldoff.Store(int64(holdoff))
			requestedBatch = clampIoUringBatchSize(c.GetInt("listen.io_uring_max_batch", ioUringDefaultMaxBatch), 0)
		} else {
			u.ioUringHoldoff.Store(int64(ioUringDefaultHoldoff))
			requestedBatch = ioUringDefaultMaxBatch
		}
		if !u.enableGSO {
			if len(u.sendShards) != 1 {
				u.resizeSendShards(1)
			}
		}
		u.ioUringMaxBatch.Store(int64(requestedBatch))
		ring, err := newIoUringState(configured)
		if err != nil {
			u.l.WithError(err).Warn("Failed to enable io_uring; falling back to sendmmsg path")
			return
		}
		u.ioState.Store(ring)
		finalBatch := clampIoUringBatchSize(requestedBatch, ring.sqEntryCount)
		u.ioUringMaxBatch.Store(int64(finalBatch))
		fields := logrus.Fields{
			"entries":   ring.sqEntryCount,
			"max_batch": finalBatch,
		}
		if finalBatch != requestedBatch {
			fields["requested_batch"] = requestedBatch
		}
		u.l.WithFields(fields).Debug("io_uring ioState pointer initialized")
		desired := configured
		if desired == 0 {
			desired = defaultIoUringEntries
		}
		if ring.sqEntryCount < desired {
			fields["requested_entries"] = desired
			u.l.WithFields(fields).Warn("UDP io_uring send path enabled with reduced queue depth (ENOMEM)")
		} else {
			u.l.WithFields(fields).Debug("UDP io_uring send path enabled")
		}

		// Initialize dedicated receive ring with retry logic
		recvPoolSize := 128 // Number of receive operations to keep queued
		recvBufferSize := defaultGROReadBufferSize
		if recvBufferSize < MTU {
			recvBufferSize = MTU
		}

		var recvRing *ioUringRecvState
		maxRetries := 10
		retryDelay := 10 * time.Millisecond

		for attempt := 0; attempt < maxRetries; attempt++ {
			var err error
			recvRing, err = newIoUringRecvState(u.sysFd, configured, recvPoolSize, recvBufferSize)
			if err == nil {
				break
			}

			if attempt < maxRetries-1 {
				u.l.WithFields(logrus.Fields{
					"attempt": attempt + 1,
					"error":   err,
					"delay":   retryDelay,
				}).Warn("Failed to create io_uring receive ring, retrying")
				time.Sleep(retryDelay)
				retryDelay *= 2 // Exponential backoff
			} else {
				u.l.WithError(err).Error("Failed to create io_uring receive ring after retries; will use standard recvmsg")
			}
		}

		if recvRing != nil {
			u.ioRecvState.Store(recvRing)
			u.ioRecvActive.Store(true)
			u.l.WithFields(logrus.Fields{
				"entries":    recvRing.sqEntryCount,
				"poolSize":   recvPoolSize,
				"bufferSize": recvBufferSize,
			}).Info("UDP io_uring receive path enabled")
			// Note: receive queue will be filled on first receivePackets() call
		}

		return
	}

	if c != nil {
		if u.ioState.Load() != nil {
			u.l.Warn("Runtime disabling of io_uring is not supported; keeping existing ring active until shutdown")
		}
		holdoff := c.GetDuration("listen.io_uring_batch_holdoff", -1)
		if holdoff < 0 {
			holdoffVal := c.GetInt("listen.io_uring_batch_holdoff", int(ioUringDefaultHoldoff/time.Microsecond))
			holdoff = time.Duration(holdoffVal) * time.Microsecond
		}
		if holdoff < ioUringMinHoldoff {
			holdoff = ioUringMinHoldoff
		}
		if holdoff > ioUringMaxHoldoff {
			holdoff = ioUringMaxHoldoff
		}
		u.ioUringHoldoff.Store(int64(holdoff))
		requestedBatch := clampIoUringBatchSize(c.GetInt("listen.io_uring_max_batch", ioUringDefaultMaxBatch), 0)
		if ring := u.ioState.Load(); ring != nil {
			requestedBatch = clampIoUringBatchSize(requestedBatch, ring.sqEntryCount)
		}
		u.ioUringMaxBatch.Store(int64(requestedBatch))
		if !u.enableGSO {
			// io_uring uses a single shared ring with a global mutex,
			// so multiple shards cause severe lock contention.
			// Force 1 shard for optimal io_uring batching performance.
			if ring := u.ioState.Load(); ring != nil {
				if len(u.sendShards) != 1 {
					u.resizeSendShards(1)
				}
			} else {
				// No io_uring, allow config override
				shards := c.GetInt("listen.send_shards", 0)
				if shards <= 0 {
					shards = 1
				}
				if len(u.sendShards) != shards {
					u.resizeSendShards(shards)
				}
			}
		}
	}
}
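
// disableIOUring tears down the send ring and reverts to the
// sendmmsg/recvmmsg path, logging the triggering error when one is supplied.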
func (u *StdConn) disableIOUring(reason error) {
	if ring := u.ioState.Swap(nil); ring != nil {
		if err := ring.Close(); err != nil {
			u.l.WithError(err).Warn("Failed to close io_uring state during disable")
		}
		if reason != nil {
			u.l.WithError(reason).Warn("Disabling io_uring send/receive path; falling back to sendmmsg/recvmmsg")
		} else {
			u.l.Warn("Disabling io_uring send/receive path; falling back to sendmmsg/recvmmsg")
		}
	}
}
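
// configureGRO toggles the UDP_GRO socket option and keeps controlLen in
// sync with whether GRO control messages are expected on receive.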
func (u *StdConn) configureGRO(enable bool) {
	if enable == u.enableGRO {
		if enable {
			u.controlLen.Store(int32(unix.CmsgSpace(2)))
		} else {
			u.controlLen.Store(0)
		}
		return
	}

	if enable {
		if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
			u.l.WithError(err).Warn("Failed to enable UDP GRO")
			u.enableGRO = false
			u.controlLen.Store(0)
			return
		}
		u.enableGRO = true
		u.controlLen.Store(int32(unix.CmsgSpace(2)))
		u.l.Info("UDP GRO enabled")
	} else {
		if u.enableGRO {
			if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil {
				u.l.WithError(err).Warn("Failed to disable UDP GRO")
			}
		}
		u.enableGRO = false
		u.controlLen.Store(0)
	}
}
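
// configureGSO enables or disables UDP GSO batching, clamping
// listen.gso_max_segments, listen.gso_max_bytes and listen.gso_flush_timeout
// to safe values and resetting per-shard pending state.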
func (u *StdConn) configureGSO(enable bool, c *config.C) {
	if len(u.sendShards) == 0 {
		u.initSendShards()
	}
	desiredShards := 0
	if c != nil {
		desiredShards = c.GetInt("listen.send_shards", 0)
	}

	// io_uring requires 1 shard due to shared ring mutex contention
	if u.ioState.Load() != nil {
		if desiredShards > 1 {
			u.l.WithField("requested_shards", desiredShards).Warn("listen.send_shards ignored because io_uring is enabled; forcing 1 send shard")
		}
		desiredShards = 1
	} else if !enable {
		if c != nil && desiredShards > 1 {
			u.l.WithField("requested_shards", desiredShards).Warn("listen.send_shards ignored because UDP GSO is disabled; forcing 1 send shard")
		}
		desiredShards = 1
	}

	// Only resize if actually changing shard count
	if len(u.sendShards) != desiredShards {
		u.resizeSendShards(desiredShards)
	}

	if !enable {
		if u.enableGSO {
			for _, shard := range u.sendShards {
				shard.mu.Lock()
				if shard.pendingSegments > 0 {
					if err := shard.flushPendingLocked(); err != nil {
						u.l.WithError(err).Warn("Failed to flush GSO buffers while disabling")
					}
				} else {
					shard.stopFlushTimerLocked()
				}
				buf := shard.pendingBuf
				shard.pendingBuf = nil
				shard.mu.Unlock()
				if buf != nil {
					u.releaseGSOBuf(buf)
				}
			}
			u.enableGSO = false
			u.l.Info("UDP GSO disabled")
		}
		u.setGroBufferSize(defaultGROReadBufferSize)
		return
	}

	// From here on GSO is being enabled; c is expected to be non-nil
	// (ReloadConfig always supplies one).
	maxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
	if maxSegments < 2 {
		maxSegments = 2
	}

	maxBytes := c.GetInt("listen.gso_max_bytes", 0)
	if maxBytes <= 0 {
		maxBytes = defaultGSOMaxBytes
	}
	if maxBytes < MTU {
		maxBytes = MTU
	}
	if maxBytes > linuxMaxGSOBatchBytes {
		u.l.WithFields(logrus.Fields{
			"configured_bytes": maxBytes,
			"clamped_bytes":    linuxMaxGSOBatchBytes,
		}).Warn("listen.gso_max_bytes exceeds Linux UDP limit; clamping")
		maxBytes = linuxMaxGSOBatchBytes
	}

	flushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
	if flushTimeout < 0 {
		flushTimeout = 0
	}

	u.enableGSO = true
	u.gsoMaxSegments = maxSegments
	u.gsoMaxBytes = maxBytes
	u.gsoFlushTimeout = flushTimeout
	bufSize := defaultGROReadBufferSize
	if u.gsoMaxBytes > bufSize {
		bufSize = u.gsoMaxBytes
	}
	u.setGroBufferSize(bufSize)

	for _, shard := range u.sendShards {
		shard.mu.Lock()
		if shard.pendingBuf != nil {
			u.releaseGSOBuf(shard.pendingBuf)
			shard.pendingBuf = nil
		}
		shard.pendingSegments = 0
		shard.pendingSegSize = 0
		shard.pendingAddr = netip.AddrPort{}
		shard.stopFlushTimerLocked()
		if len(shard.controlBuf) < unix.CmsgSpace(2) {
			shard.controlBuf = make([]byte, unix.CmsgSpace(2))
		}
		shard.mu.Unlock()
	}

	u.l.WithFields(logrus.Fields{
		"segments":      u.gsoMaxSegments,
		"bytes":         u.gsoMaxBytes,
		"flush_timeout": u.gsoFlushTimeout,
	}).Info("UDP GSO configured")
}
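
// ReloadConfig applies runtime-tunable listen.* settings: socket buffer
// sizes, SO_MARK, and the io_uring, GRO and GSO feature toggles.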
func (u *StdConn) ReloadConfig(c *config.C) {
	b := c.GetInt("listen.read_buffer", 0)
	if b > 0 {
		err := u.SetRecvBuffer(b)
		if err == nil {
			s, err := u.GetRecvBuffer()
			if err == nil {
				u.l.WithField("size", s).Info("listen.read_buffer was set")
			} else {
				u.l.WithError(err).Warn("Failed to get listen.read_buffer")
			}
		} else {
			u.l.WithError(err).Error("Failed to set listen.read_buffer")
		}
	}

	b = c.GetInt("listen.write_buffer", 0)
	if b > 0 {
		err := u.SetSendBuffer(b)
		if err == nil {
			s, err := u.GetSendBuffer()
			if err == nil {
				u.l.WithField("size", s).Info("listen.write_buffer was set")
			} else {
				u.l.WithError(err).Warn("Failed to get listen.write_buffer")
			}
		} else {
			u.l.WithError(err).Error("Failed to set listen.write_buffer")
		}
	}

	b = c.GetInt("listen.so_mark", 0)
	s, err := u.GetSoMark()
	if b > 0 || (err == nil && s != 0) {
		err := u.SetSoMark(b)
		if err == nil {
			s, err := u.GetSoMark()
			if err == nil {
				u.l.WithField("mark", s).Info("listen.so_mark was set")
			} else {
				u.l.WithError(err).Warn("Failed to get listen.so_mark")
			}
		} else {
			u.l.WithError(err).Error("Failed to set listen.so_mark")
		}
	}

	u.configureIOUring(c.GetBool("listen.use_io_uring", false), c)
	u.configureGRO(c.GetBool("listen.enable_gro", false))
	u.configureGSO(c.GetBool("listen.enable_gso", false), c)
}
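
// getMemInfo fetches SO_MEMINFO for the socket via a raw getsockopt syscall.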
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
	var vallen uint32 = 4 * unix.SK_MEMINFO_VARS
	_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
	if err != 0 {
		return err
	}
	return nil
}
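
// Close shuts the socket down once, flushes and stops every send shard, and
// tears down any io_uring send/receive state. The first flush or ring error
// takes precedence over the close(2) result.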
func (u *StdConn) Close() error {
	if !u.ioClosing.CompareAndSwap(false, true) {
		return nil
	}
	// Attempt to unblock any outstanding sendmsg/sendmmsg calls so the shard
	// workers can drain promptly during shutdown. Ignoring errors here is fine
	// because some platforms/kernels may not support shutdown on UDP sockets.
	if err := unix.Shutdown(u.sysFd, unix.SHUT_RDWR); err != nil && err != unix.ENOTCONN && err != unix.EINVAL && err != unix.EBADF {
		u.l.WithError(err).Debug("Failed to shutdown UDP socket for close")
	}

	var flushErr error
	for _, shard := range u.sendShards {
		if shard == nil {
			continue
		}
		shard.mu.Lock()
		if shard.pendingSegments > 0 {
			if err := shard.flushPendingLocked(); err != nil && flushErr == nil {
				flushErr = err
			}
		} else {
			shard.stopFlushTimerLocked()
		}
		buf := shard.pendingBuf
		shard.pendingBuf = nil
		shard.mu.Unlock()
		if buf != nil {
			u.releaseGSOBuf(buf)
		}
		shard.stopSender()
	}

	closeErr := syscall.Close(u.sysFd)
	if ring := u.ioState.Swap(nil); ring != nil {
		if err := ring.Close(); err != nil && flushErr == nil {
			flushErr = err
		}
	}
	if recvRing := u.ioRecvState.Swap(nil); recvRing != nil {
		u.ioRecvActive.Store(false)
		if err := recvRing.Close(); err != nil && flushErr == nil {
			flushErr = err
		}
	}
	if flushErr != nil {
		return flushErr
	}
	return closeErr
}
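
// NewUDPStatsEmitter returns a closure that refreshes per-socket SO_MEMINFO
// gauges (when the kernel supports the option) and asks each StdConn to log
// its per-tick GSO/GRO counters.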
func NewUDPStatsEmitter(udpConns []Conn) func() {
	// Check if our kernel supports SO_MEMINFO before registering the gauges
	var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
	var meminfo [unix.SK_MEMINFO_VARS]uint32
	if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
		udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
		for i := range udpConns {
			udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
			}
		}
	}

	var stdConns []*StdConn
	for _, conn := range udpConns {
		if sc, ok := conn.(*StdConn); ok {
			stdConns = append(stdConns, sc)
		}
	}

	return func() {
		for i, gauges := range udpGauges {
			if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
				for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
					gauges[j].Update(int64(meminfo[j]))
				}
			}
		}

		for _, sc := range stdConns {
			sc.logGSOTick()
		}
	}
}