Compare commits

...

5 Commits

Author SHA1 Message Date
Ryan
2c6f81c224 config tweaks for batching 2025-11-06 10:01:20 -05:00
Ryan
ad37749c5e add batching of packets 2025-11-06 09:42:13 -05:00
Ryan
a0f8cb2098 works properly 2025-11-05 22:09:06 -05:00
Ryan
d18d1aea67 first 2025-11-05 20:34:02 -05:00
Ryan
f5ff534671 make it work with dnclient 2025-11-05 19:25:32 -05:00
10 changed files with 838 additions and 53 deletions

View File

@@ -1,8 +1,10 @@
package cert package cert
import ( import (
"encoding/hex"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"time"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@@ -138,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
} }
} }
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
issuerBytes, err := func() ([]byte, error) {
issuer := c.Issuer()
if issuer == "" {
return nil, nil
}
decoded, err := hex.DecodeString(issuer)
if err != nil {
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
return decoded, nil
}()
if err != nil {
return nil, rest, err
}
pubKey := c.PublicKey()
if pubKey != nil {
pubKey = append([]byte(nil), pubKey...)
}
sig := c.Signature()
if sig != nil {
sig = append([]byte(nil), sig...)
}
return &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: pubKey,
IsCA: c.IsCA(),
Issuer: issuerBytes,
Curve: c.Curve(),
},
Signature: sig,
cert: c,
}, rest, nil
}
// IssuerString returns the issuer in hex format for compatibility
func (n *NebulaCertificate) IssuerString() string {
if n.Details.Issuer == nil {
return ""
}
return hex.EncodeToString(n.Details.Issuer)
}
// Certificate returns the underlying certificate (read-only)
func (n *NebulaCertificate) Certificate() Certificate {
return n.cert
}
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non // UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
// consumed data or an error on failure // consumed data or an error on failure
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {

View File

@@ -132,6 +132,13 @@ listen:
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
# default is 64, does not support reload # default is 64, does not support reload
#batch: 64 #batch: 64
# Control batching between UDP and TUN pipelines
#batch:
# inbound_size: 32 # packets to queue from UDP before handing to workers
# outbound_size: 32 # packets to queue from TUN before handing to workers
# flush_interval: 50us # flush partially filled batches after this duration
# max_outstanding: 1028 # batches buffered per routine on each channel
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel # Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default) # Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default)
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide # Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide

View File

@@ -22,7 +22,14 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
const mtu = 9001 const (
mtu = 9001
inboundBatchSizeDefault = 32
outboundBatchSizeDefault = 32
batchFlushIntervalDefault = 50 * time.Microsecond
maxOutstandingBatchesDefault = 1028
)
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
@@ -48,9 +55,17 @@ type InterfaceConfig struct {
reQueryWait time.Duration reQueryWait time.Duration
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
BatchConfig BatchConfig
l *logrus.Logger l *logrus.Logger
} }
type BatchConfig struct {
InboundBatchSize int
OutboundBatchSize int
FlushInterval time.Duration
MaxOutstandingPerChan int
}
type Interface struct { type Interface struct {
hostMap *HostMap hostMap *HostMap
outside udp.Conn outside udp.Conn
@@ -97,10 +112,86 @@ type Interface struct {
l *logrus.Logger l *logrus.Logger
inPool sync.Pool inPool sync.Pool
inbound chan *packet.Packet inbound []chan *packetBatch
outPool sync.Pool outPool sync.Pool
outbound chan *[]byte outbound []chan *outboundBatch
packetBatchPool sync.Pool
outboundBatchPool sync.Pool
inboundBatchSize int
outboundBatchSize int
batchFlushInterval time.Duration
maxOutstandingPerChan int
}
type packetBatch struct {
packets []*packet.Packet
}
func newPacketBatch(capacity int) *packetBatch {
return &packetBatch{
packets: make([]*packet.Packet, 0, capacity),
}
}
func (b *packetBatch) add(p *packet.Packet) {
b.packets = append(b.packets, p)
}
func (b *packetBatch) reset() {
for i := range b.packets {
b.packets[i] = nil
}
b.packets = b.packets[:0]
}
func (f *Interface) getPacketBatch() *packetBatch {
if v := f.packetBatchPool.Get(); v != nil {
b := v.(*packetBatch)
b.reset()
return b
}
return newPacketBatch(f.inboundBatchSize)
}
func (f *Interface) releasePacketBatch(b *packetBatch) {
b.reset()
f.packetBatchPool.Put(b)
}
type outboundBatch struct {
payloads []*[]byte
}
func newOutboundBatch(capacity int) *outboundBatch {
return &outboundBatch{payloads: make([]*[]byte, 0, capacity)}
}
func (b *outboundBatch) add(buf *[]byte) {
b.payloads = append(b.payloads, buf)
}
func (b *outboundBatch) reset() {
for i := range b.payloads {
b.payloads[i] = nil
}
b.payloads = b.payloads[:0]
}
func (f *Interface) getOutboundBatch() *outboundBatch {
if v := f.outboundBatchPool.Get(); v != nil {
b := v.(*outboundBatch)
b.reset()
return b
}
return newOutboundBatch(f.outboundBatchSize)
}
func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
b.reset()
f.outboundBatchPool.Put(b)
} }
type EncWriter interface { type EncWriter interface {
@@ -170,6 +261,20 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
} }
cs := c.pki.getCertState() cs := c.pki.getCertState()
bc := c.BatchConfig
if bc.InboundBatchSize <= 0 {
bc.InboundBatchSize = inboundBatchSizeDefault
}
if bc.OutboundBatchSize <= 0 {
bc.OutboundBatchSize = outboundBatchSizeDefault
}
if bc.FlushInterval <= 0 {
bc.FlushInterval = batchFlushIntervalDefault
}
if bc.MaxOutstandingPerChan <= 0 {
bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault
}
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
hostMap: c.HostMap, hostMap: c.HostMap,
@@ -202,11 +307,20 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
}, },
//TODO: configurable size inbound: make([]chan *packetBatch, c.routines),
inbound: make(chan *packet.Packet, 1028), outbound: make([]chan *outboundBatch, c.routines),
outbound: make(chan *[]byte, 1028),
l: c.l, l: c.l,
inboundBatchSize: bc.InboundBatchSize,
outboundBatchSize: bc.OutboundBatchSize,
batchFlushInterval: bc.FlushInterval,
maxOutstandingPerChan: bc.MaxOutstandingPerChan,
}
for i := 0; i < c.routines; i++ {
ifce.inbound[i] = make(chan *packetBatch, ifce.maxOutstandingPerChan)
ifce.outbound[i] = make(chan *outboundBatch, ifce.maxOutstandingPerChan)
} }
ifce.inPool = sync.Pool{New: func() any { ifce.inPool = sync.Pool{New: func() any {
@@ -218,6 +332,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return &t return &t
}} }}
ifce.packetBatchPool = sync.Pool{New: func() any {
return newPacketBatch(ifce.inboundBatchSize)
}}
ifce.outboundBatchPool = sync.Pool{New: func() any {
return newOutboundBatch(ifce.outboundBatchSize)
}}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -296,22 +418,41 @@ func (f *Interface) listenOut(i int) {
li = f.outside li = f.outside
} }
batch := f.getPacketBatch()
lastFlush := time.Now()
flush := func(force bool) {
if len(batch.packets) == 0 {
if force {
f.releasePacketBatch(batch)
}
return
}
f.inbound[i] <- batch
batch = f.getPacketBatch()
lastFlush = time.Now()
}
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
p := f.inPool.Get().(*packet.Packet) p := f.inPool.Get().(*packet.Packet)
//TODO: have the listener store this in the msgs array after a read instead of doing a copy
p.Payload = p.Payload[:mtu] p.Payload = p.Payload[:mtu]
copy(p.Payload, payload) copy(p.Payload, payload)
p.Payload = p.Payload[:len(payload)] p.Payload = p.Payload[:len(payload)]
p.Addr = fromUdpAddr p.Addr = fromUdpAddr
f.inbound <- p batch.add(p)
//select {
//case f.inbound <- p: if len(batch.packets) >= f.inboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
//default: flush(false)
// f.l.Error("Dropped packet from inbound channel") }
//}
}) })
if len(batch.packets) > 0 {
f.inbound[i] <- batch
} else {
f.releasePacketBatch(batch)
}
if err != nil && !f.closed.Load() { if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing") f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close //TODO: Trigger Control to close
@@ -324,6 +465,22 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
batch := f.getOutboundBatch()
lastFlush := time.Now()
flush := func(force bool) {
if len(batch.payloads) == 0 {
if force {
f.releaseOutboundBatch(batch)
}
return
}
f.outbound[i] <- batch
batch = f.getOutboundBatch()
lastFlush = time.Now()
}
for { for {
p := f.outPool.Get().(*[]byte) p := f.outPool.Get().(*[]byte)
*p = (*p)[:mtu] *p = (*p)[:mtu]
@@ -337,13 +494,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
} }
*p = (*p)[:n] *p = (*p)[:n]
//TODO: nonblocking channel write batch.add(p)
f.outbound <- p
//select { if len(batch.payloads) >= f.outboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
//case f.outbound <- p: flush(false)
//default: }
// f.l.Error("Dropped packet from outbound channel") }
//}
if len(batch.payloads) > 0 {
f.outbound[i] <- batch
} else {
f.releaseOutboundBatch(batch)
} }
f.l.Debugf("overlay reader %v is done", i) f.l.Debugf("overlay reader %v is done", i)
@@ -360,10 +521,13 @@ func (f *Interface) workerIn(i int, ctx context.Context) {
for { for {
select { select {
case p := <-f.inbound: case batch := <-f.inbound[i]:
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) for _, p := range batch.packets {
p.Payload = p.Payload[:mtu] f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
f.inPool.Put(p) p.Payload = p.Payload[:mtu]
f.inPool.Put(p)
}
f.releasePacketBatch(batch)
case <-ctx.Done(): case <-ctx.Done():
f.wg.Done() f.wg.Done()
return return
@@ -379,10 +543,13 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
for { for {
select { select {
case data := <-f.outbound: case batch := <-f.outbound[i]:
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) for _, data := range batch.payloads {
*data = (*data)[:mtu] f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
f.outPool.Put(data) *data = (*data)[:mtu]
f.outPool.Put(data)
}
f.releaseOutboundBatch(batch)
case <-ctx.Done(): case <-ctx.Done():
f.wg.Done() f.wg.Done()
return return

View File

@@ -221,6 +221,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
} }
batchCfg := BatchConfig{
InboundBatchSize: c.GetInt("batch.inbound_size", inboundBatchSizeDefault),
OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault),
FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault),
MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault),
}
ifConfig := &InterfaceConfig{ ifConfig := &InterfaceConfig{
HostMap: hostMap, HostMap: hostMap,
Inside: tun, Inside: tun,
@@ -242,6 +249,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
relayManager: NewRelayManager(ctx, l, hostMap, c), relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy, punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
BatchConfig: batchCfg,
l: l, l: l,
} }

View File

@@ -9,13 +9,10 @@ import (
"math" "math"
"net" "net"
"net/netip" "net/netip"
"os"
"strings" "strings"
"sync" "sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@@ -46,15 +43,7 @@ type Service struct {
} }
} }
func New(config *config.C) (*Service, error) { func New(control *nebula.Control) (*Service, error) {
logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return nil, err
}
wait, err := control.Start() wait, err := control.Start()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -30,8 +30,8 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) ListenOut(_ EncReader) error {
return return nil
} }
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil

View File

@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {} return func() {}
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
@@ -174,14 +174,17 @@ func (u *StdConn) ListenOut(r EncReader) {
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return return err
} }
u.l.WithError(err).Error("unexpected udp socket receive error") u.l.WithError(err).Error("unexpected udp socket receive error")
continue
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
} }
return nil
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {

View File

@@ -5,9 +5,11 @@ package udp
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
@@ -20,11 +22,38 @@ import (
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500)) var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
const (
defaultGSOMaxSegments = 8
defaultGSOFlushTimeout = 150 * time.Microsecond
defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
maxGSOBatchBytes = 0xFFFF
)
var (
errGSOFallback = errors.New("udp gso fallback")
errGSODisabled = errors.New("udp gso disabled")
)
type StdConn struct { type StdConn struct {
sysFd int sysFd int
isV4 bool isV4 bool
l *logrus.Logger l *logrus.Logger
batch int batch int
enableGRO bool
enableGSO bool
gsoMu sync.Mutex
gsoBuf []byte
gsoAddr netip.AddrPort
gsoSegSize int
gsoSegments int
gsoMaxSegments int
gsoMaxBytes int
gsoFlushTimeout time.Duration
gsoTimer *time.Timer
groBufSize int
} }
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -69,7 +98,16 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err) return nil, fmt.Errorf("unable to bind to socket: %s", err)
} }
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err return &StdConn{
sysFd: fd,
isV4: ip.Is4(),
l: l,
batch: batch,
gsoMaxSegments: defaultGSOMaxSegments,
gsoMaxBytes: MTU * defaultGSOMaxSegments,
gsoFlushTimeout: defaultGSOFlushTimeout,
groBufSize: MTU,
}, err
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
@@ -119,15 +157,42 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
} }
func (u *StdConn) ListenOut(r EncReader) error { func (u *StdConn) ListenOut(r EncReader) error {
var ip netip.Addr var (
ip netip.Addr
controls [][]byte
)
msgs, buffers, names := u.PrepareRawMessages(u.batch) bufSize := u.readBufferSize()
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
read := u.ReadMulti read := u.ReadMulti
if u.batch == 1 { if u.batch == 1 {
read = u.ReadSingle read = u.ReadSingle
} }
for { for {
desired := u.readBufferSize()
if len(buffers) == 0 || cap(buffers[0]) < desired {
msgs, buffers, names = u.PrepareRawMessages(u.batch, desired)
controls = nil
}
if u.enableGRO {
if controls == nil {
controls = make([][]byte, len(msgs))
for i := range controls {
controls[i] = make([]byte, unix.CmsgSpace(4))
}
}
for i := range msgs {
setRawMessageControl(&msgs[i], controls[i])
}
} else if controls != nil {
for i := range msgs {
setRawMessageControl(&msgs[i], nil)
}
controls = nil
}
n, err := read(msgs) n, err := read(msgs)
if err != nil { if err != nil {
return err return err
@@ -140,11 +205,82 @@ func (u *StdConn) ListenOut(r EncReader) error {
} else { } else {
ip, _ = netip.AddrFromSlice(names[i][8:24]) ip, _ = netip.AddrFromSlice(names[i][8:24])
} }
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payload := buffers[i][:msgs[i].Len]
if u.enableGRO && u.l.IsLevelEnabled(logrus.DebugLevel) {
ctrlLen := getRawMessageControlLen(&msgs[i])
msgFlags := getRawMessageFlags(&msgs[i])
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "recv",
"payload_len": len(payload),
"ctrl_len": ctrlLen,
"msg_flags": msgFlags,
}).Debug("gro batch data")
if controls != nil && ctrlLen > 0 {
maxDump := ctrlLen
if maxDump > 16 {
maxDump = 16
}
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "control-bytes",
"control_hex": fmt.Sprintf("%x", controls[i][:maxDump]),
"datalen": ctrlLen,
}).Debug("gro control dump")
}
}
sawControl := false
if controls != nil {
if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segSize > 0 {
sawControl = true
if u.l.IsLevelEnabled(logrus.DebugLevel) {
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "control",
"seg_size": segSize,
"seg_count": segCount,
"payloadLen": len(payload),
}).Debug("gro control parsed")
}
segSize = normalizeGROSegSize(segSize, segCount, len(payload))
if segSize > 0 && segSize < len(payload) {
if u.emitGROSegments(r, addr, payload, segSize) {
continue
}
}
}
}
}
if u.enableGRO && len(payload) > MTU {
if !sawControl && u.l.IsLevelEnabled(logrus.DebugLevel) {
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "fallback",
"payload_len": len(payload),
}).Debug("gro control missing; splitting payload by MTU")
}
if u.emitGROSegments(r, addr, payload, MTU) {
continue
}
}
r(addr, payload)
} }
} }
} }
func (u *StdConn) readBufferSize() int {
if u.enableGRO && u.groBufSize > MTU {
return u.groBufSize
}
return MTU
}
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for { for {
n, _, err := unix.Syscall6( n, _, err := unix.Syscall6(
@@ -193,6 +329,14 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
} }
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
if u.enableGSO && ip.IsValid() {
if err := u.queueGSOPacket(b, ip); err == nil {
return nil
} else if !errors.Is(err, errGSOFallback) {
return err
}
}
if u.isV4 { if u.isV4 {
return u.writeTo4(b, ip) return u.writeTo4(b, ip)
} }
@@ -299,6 +443,94 @@ func (u *StdConn) ReloadConfig(c *config.C) {
u.l.WithError(err).Error("Failed to set listen.so_mark") u.l.WithError(err).Error("Failed to set listen.so_mark")
} }
} }
u.configureGRO(c)
u.configureGSO(c)
}
func (u *StdConn) configureGRO(c *config.C) {
if c == nil {
return
}
enable := c.GetBool("listen.enable_gro", false)
if enable == u.enableGRO {
if enable {
if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
u.setGROBufferSize(size)
}
}
return
}
if enable {
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
u.l.WithError(err).Warn("Failed to enable UDP GRO")
return
}
u.enableGRO = true
u.setGROBufferSize(c.GetInt("listen.gro_read_buffer", defaultGROReadBufferSize))
u.l.WithField("buffer_size", u.groBufSize).Info("UDP GRO enabled")
return
}
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
u.l.WithError(err).Warn("Failed to disable UDP GRO")
}
u.enableGRO = false
u.groBufSize = MTU
}
func (u *StdConn) configureGSO(c *config.C) {
enable := c.GetBool("listen.enable_gso", false)
if !enable {
u.disableGSO()
} else {
u.enableGSO = true
}
segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
if segments < 1 {
segments = 1
}
u.gsoMaxSegments = segments
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
if maxBytes <= 0 {
maxBytes = MTU * segments
}
if maxBytes > maxGSOBatchBytes {
u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
maxBytes = maxGSOBatchBytes
}
u.gsoMaxBytes = maxBytes
timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
if timeout < 0 {
timeout = 0
}
u.gsoFlushTimeout = timeout
}
func (u *StdConn) setGROBufferSize(size int) {
if size < MTU {
size = defaultGROReadBufferSize
}
if size > maxGSOBatchBytes {
size = maxGSOBatchBytes
}
u.groBufSize = size
}
func (u *StdConn) disableGSO() {
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
u.enableGSO = false
_ = u.flushGSOlocked()
u.gsoBuf = nil
u.gsoSegments = 0
u.gsoSegSize = 0
u.stopGSOTimerLocked()
} }
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
@@ -310,7 +542,239 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
return nil return nil
} }
func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
if len(b) == 0 {
return nil
}
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
if !u.enableGSO || !addr.IsValid() || len(b) > u.gsoMaxBytes {
if err := u.flushGSOlocked(); err != nil {
return err
}
return errGSOFallback
}
if u.gsoSegments == 0 {
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
} else if addr != u.gsoAddr || len(b) != u.gsoSegSize {
if err := u.flushGSOlocked(); err != nil {
return err
}
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
}
if len(u.gsoBuf)+len(b) > u.gsoMaxBytes {
if err := u.flushGSOlocked(); err != nil {
return err
}
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
}
u.gsoBuf = append(u.gsoBuf, b...)
u.gsoSegments++
if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
return u.flushGSOlocked()
}
u.scheduleGSOFlushLocked()
return nil
}
func (u *StdConn) flushGSOlocked() error {
if u.gsoSegments == 0 {
u.stopGSOTimerLocked()
return nil
}
payload := append([]byte(nil), u.gsoBuf...)
addr := u.gsoAddr
segSize := u.gsoSegSize
u.gsoBuf = u.gsoBuf[:0]
u.gsoSegments = 0
u.gsoSegSize = 0
u.stopGSOTimerLocked()
if segSize <= 0 {
return errGSOFallback
}
err := u.sendSegmented(payload, addr, segSize)
if errors.Is(err, errGSODisabled) {
u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
u.enableGSO = false
return u.sendSegmentsIndividually(payload, addr, segSize)
}
return err
}
func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
if len(payload) == 0 {
return nil
}
control := make([]byte, unix.CmsgSpace(2))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
var sa unix.Sockaddr
if addr.Addr().Is4() {
var sa4 unix.SockaddrInet4
sa4.Port = int(addr.Port())
sa4.Addr = addr.Addr().As4()
sa = &sa4
} else {
var sa6 unix.SockaddrInet6
sa6.Port = int(addr.Port())
sa6.Addr = addr.Addr().As16()
sa = &sa6
}
if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
return errGSODisabled
}
return &net.OpError{Op: "sendmsg", Err: err}
}
return nil
}
func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
if segSize <= 0 {
return errGSOFallback
}
for offset := 0; offset < len(buf); offset += segSize {
end := offset + segSize
if end > len(buf) {
end = len(buf)
}
var err error
if u.isV4 {
err = u.writeTo4(buf[offset:end], addr)
} else {
err = u.writeTo6(buf[offset:end], addr)
}
if err != nil {
return err
}
}
return nil
}
func (u *StdConn) scheduleGSOFlushLocked() {
if u.gsoTimer == nil {
u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
return
}
u.gsoTimer.Reset(u.gsoFlushTimeout)
}
func (u *StdConn) stopGSOTimerLocked() {
if u.gsoTimer != nil {
u.gsoTimer.Stop()
u.gsoTimer = nil
}
}
func (u *StdConn) gsoFlushTimer() {
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
_ = u.flushGSOlocked()
}
func parseGROControl(control []byte) (int, int) {
if len(control) == 0 {
return 0, 0
}
cmsgs, err := unix.ParseSocketControlMessage(control)
if err != nil {
return 0, 0
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
segCount := 0
if len(c.Data) >= 4 {
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
}
return segSize, segCount
}
}
return 0, 0
}
func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
if segSize <= 0 {
return false
}
for offset := 0; offset < len(payload); offset += segSize {
end := offset + segSize
if end > len(payload) {
end = len(payload)
}
segment := make([]byte, end-offset)
copy(segment, payload[offset:end])
r(addr, segment)
}
return true
}
func normalizeGROSegSize(segSize, segCount, total int) int {
if segSize <= 0 || total <= 0 {
return segSize
}
if segSize > total && segCount > 0 {
segSize = total / segCount
if segSize == 0 {
segSize = total
}
}
if segCount <= 1 && segSize > 0 && total > segSize {
calculated := total / segSize
if calculated <= 1 {
calculated = (total + segSize - 1) / segSize
}
if calculated > 1 {
segCount = calculated
}
}
if segSize > MTU {
return MTU
}
return segSize
}
func (u *StdConn) Close() error { func (u *StdConn) Close() error {
u.disableGSO()
return syscall.Close(u.sysFd) return syscall.Close(u.sysFd)
} }

View File

@@ -30,13 +30,16 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
if bufSize <= 0 {
bufSize = MTU
}
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
@@ -52,3 +55,25 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names return msgs, buffers, names
} }
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint32(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func getRawMessageFlags(msg *rawMessage) int {
return int(msg.Hdr.Flags)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint32(l)
}

View File

@@ -33,13 +33,16 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
if bufSize <= 0 {
bufSize = MTU
}
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
@@ -55,3 +58,25 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names return msgs, buffers, names
} }
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint64(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func getRawMessageFlags(msg *rawMessage) int {
return int(msg.Hdr.Flags)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}