Compare commits


5 Commits

Author  SHA1        Message                      Date
Ryan    2c6f81c224  config tweaks for batching   2025-11-06 10:01:20 -05:00
Ryan    ad37749c5e  add batching of packets      2025-11-06 09:42:13 -05:00
Ryan    a0f8cb2098  works properly               2025-11-05 22:09:06 -05:00
Ryan    d18d1aea67  first                        2025-11-05 20:34:02 -05:00
Ryan    f5ff534671  make it work with dnclient   2025-11-05 19:25:32 -05:00
14 changed files with 852 additions and 458 deletions

View File

@@ -1,8 +1,10 @@
package cert
import (
"encoding/hex"
"encoding/pem"
"fmt"
"time"
"golang.org/x/crypto/ed25519"
)
@@ -138,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
}
}
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
issuerBytes, err := func() ([]byte, error) {
issuer := c.Issuer()
if issuer == "" {
return nil, nil
}
decoded, err := hex.DecodeString(issuer)
if err != nil {
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
return decoded, nil
}()
if err != nil {
return nil, rest, err
}
pubKey := c.PublicKey()
if pubKey != nil {
pubKey = append([]byte(nil), pubKey...)
}
sig := c.Signature()
if sig != nil {
sig = append([]byte(nil), sig...)
}
return &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: pubKey,
IsCA: c.IsCA(),
Issuer: issuerBytes,
Curve: c.Curve(),
},
Signature: sig,
cert: c,
}, rest, nil
}
// IssuerString returns the issuer in hex format for compatibility
func (n *NebulaCertificate) IssuerString() string {
if n.Details.Issuer == nil {
return ""
}
return hex.EncodeToString(n.Details.Issuer)
}
// Certificate returns the underlying certificate (read-only)
func (n *NebulaCertificate) Certificate() Certificate {
return n.cert
}
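
A hypothetical usage sketch of the compatibility wrapper above, assuming this branch's github.com/slackhq/nebula/cert package; the program and file path are illustrative, not part of the change:

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/slackhq/nebula/cert"
)

func main() {
	// Path is illustrative; point it at any PEM-encoded Nebula certificate.
	pemBytes, err := os.ReadFile("host.crt")
	if err != nil {
		log.Fatal(err)
	}
	nc, rest, err := cert.UnmarshalNebulaCertificateFromPEM(pemBytes)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("name=%s ca=%v issuer=%s leftover=%d bytes\n",
		nc.Details.Name, nc.Details.IsCA, nc.IssuerString(), len(rest))
}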
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
// consumed data or an error on failure
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {

View File

@@ -1,191 +0,0 @@
package main
import (
"encoding/binary"
"errors"
"flag"
"fmt"
"log"
"net"
"net/netip"
"time"
"unsafe"
"golang.org/x/sys/unix"
)
const (
// UDP_SEGMENT enables GSO segmentation
UDP_SEGMENT = 103
// Maximum GSO segment size (typical MTU - headers)
maxGSOSize = 1400
)
func main() {
destAddr := flag.String("dest", "10.4.0.16:4202", "Destination address")
gsoSize := flag.Int("gso", 1400, "GSO segment size")
totalSize := flag.Int("size", 14000, "Total payload size to send")
count := flag.Int("count", 1, "Number of packets to send")
flag.Parse()
if *gsoSize > maxGSOSize {
log.Fatalf("GSO size %d exceeds maximum %d", *gsoSize, maxGSOSize)
}
// Resolve destination address
_, err := net.ResolveUDPAddr("udp", *destAddr)
if err != nil {
log.Fatalf("Failed to resolve address: %v", err)
}
// Create a raw UDP socket with GSO support
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err != nil {
log.Fatalf("Failed to create socket: %v", err)
}
defer unix.Close(fd)
// Bind to a local address
localAddr := &unix.SockaddrInet4{
Port: 0, // Let the system choose a port
}
if err := unix.Bind(fd, localAddr); err != nil {
log.Fatalf("Failed to bind socket: %v", err)
}
fmt.Printf("Sending UDP packets with GSO enabled\n")
fmt.Printf("Destination: %s\n", *destAddr)
fmt.Printf("GSO segment size: %d bytes\n", *gsoSize)
fmt.Printf("Total payload size: %d bytes\n", *totalSize)
fmt.Printf("Number of packets: %d\n\n", *count)
// Create payload
payload := make([]byte, *totalSize)
for i := range payload {
payload[i] = byte(i % 256)
}
dest := netip.MustParseAddrPort(*destAddr)
//if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_SEGMENT, 1400); err != nil {
// panic(err)
//}
for i := 0; i < *count; i++ {
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), true)
if err != nil {
log.Printf("Send error on packet %d: %v", i, err)
continue
}
if (i+1)%100 == 0 || i == *count-1 {
fmt.Printf("Sent %d packets\n", i+1)
}
}
fmt.Printf("now, let's send without the correct ctrl header\n")
time.Sleep(time.Second)
for i := 0; i < *count; i++ {
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), false)
if err != nil {
log.Printf("Send error on packet %d: %v", i, err)
continue
}
if (i+1)%100 == 0 || i == *count-1 {
fmt.Printf("Sent %d packets\n", i+1)
}
}
}
func WriteBatch(fd int, payload []byte, addr netip.AddrPort, segSize uint16, withHeader bool) error {
msgs := make([]rawMessage, 0, 1)
iovs := make([]iovec, 0, 1)
names := make([][unix.SizeofSockaddrInet6]byte, 0, 1)
sent := 0
pkts := []BatchPacket{
{
Payload: payload,
Addr: addr,
},
}
for _, pkt := range pkts {
if len(pkt.Payload) == 0 {
sent++
continue
}
msgs = append(msgs, rawMessage{})
iovs = append(iovs, iovec{})
names = append(names, [unix.SizeofSockaddrInet6]byte{})
idx := len(msgs) - 1
msg := &msgs[idx]
iov := &iovs[idx]
name := &names[idx]
setIovecSlice(iov, pkt.Payload)
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
if withHeader {
setRawMessageControl(msg, buildGSOControlMessage(segSize)) //
} else {
setRawMessageControl(msg, nil) //
}
msg.Hdr.Flags = 0
nameLen, err := encodeSockaddr(name[:], pkt.Addr)
if err != nil {
return err
}
msg.Hdr.Name = &name[0]
msg.Hdr.Namelen = nameLen
}
if len(msgs) == 0 {
return errors.New("nothing to write")
}
offset := 0
for offset < len(msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(fd),
uintptr(unsafe.Pointer(&msgs[offset])),
uintptr(len(msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
return &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return nil
}
func buildGSOControlMessage(segSize uint16) []byte {
control := make([]byte, unix.CmsgSpace(2))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
return control
}
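
The commented-out SetsockoptInt block in main points at the socket-wide alternative to attaching a UDP_SEGMENT cmsg to every message. A minimal sketch of that approach, assuming a Linux kernel with UDP GSO support (roughly 4.18+); the 1400-byte segment size mirrors the tool's default:

package main

import (
	"log"

	"golang.org/x/sys/unix"
)

func main() {
	// Same kind of datagram socket the tool creates.
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
	if err != nil {
		log.Fatalf("socket: %v", err)
	}
	defer unix.Close(fd)

	// Enable GSO once for the whole socket; plain sendto/sendmmsg writes are
	// then segmented by the kernel at 1400 bytes with no per-message cmsg.
	// Older kernels fail here with ENOPROTOOPT or EINVAL.
	if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_SEGMENT, 1400); err != nil {
		log.Fatalf("UDP_SEGMENT unsupported: %v", err)
	}
	log.Println("socket-wide GSO enabled")
}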

View File

@@ -1,85 +0,0 @@
package main
import (
"encoding/binary"
"fmt"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
type iovec struct {
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Pad0 [4]byte
Iov *iovec
Iovlen uint64
Control *byte
Controllen uint64
Flags int32
Pad1 [4]byte
}
type rawMessage struct {
Hdr msghdr
Len uint32
Pad0 [4]byte
}
type BatchPacket struct {
Payload []byte
Addr netip.AddrPort
}
func encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
if addr.Addr().Is4() {
if !addr.Addr().Is4() {
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var sa unix.RawSockaddrInet4
sa.Family = unix.AF_INET
sa.Addr = addr.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet4
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
var sa unix.RawSockaddrInet6
sa.Family = unix.AF_INET6
sa.Addr = addr.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet6
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint64(len(buf))
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint64(len(b))
}
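
For comparison only, the destination encoding that encodeSockaddr builds by hand with unsafe can also be left to the Sockaddr types in golang.org/x/sys/unix; a hypothetical standalone sketch (sendOne is an illustrative helper, not part of this change, and it gives up sendmmsg batching):

package main

import (
	"log"
	"net/netip"

	"golang.org/x/sys/unix"
)

// sendOne writes a single datagram, letting the unix package marshal the raw
// sockaddr layout instead of copying struct bytes manually.
func sendOne(fd int, payload []byte, addr netip.AddrPort) error {
	var sa unix.Sockaddr
	if addr.Addr().Is4() {
		sa = &unix.SockaddrInet4{Port: int(addr.Port()), Addr: addr.Addr().As4()}
	} else {
		sa = &unix.SockaddrInet6{Port: int(addr.Port()), Addr: addr.Addr().As16()}
	}
	return unix.Sendto(fd, payload, 0, sa)
}

func main() {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
	if err != nil {
		log.Fatalf("socket: %v", err)
	}
	defer unix.Close(fd)
	// Destination is illustrative only.
	if err := sendOne(fd, []byte("ping"), netip.MustParseAddrPort("127.0.0.1:4242")); err != nil {
		log.Fatalf("send: %v", err)
	}
}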

View File

@@ -15,7 +15,7 @@ import (
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps
const ReplayWindow = 8192
const ReplayWindow = 1024
type ConnectionState struct {
eKey *NebulaCipherState

View File

@@ -132,6 +132,13 @@ listen:
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
# default is 64, does not support reload
#batch: 64
# Control batching between UDP and TUN pipelines
#batch:
# inbound_size: 32 # packets to queue from UDP before handing to workers
# outbound_size: 32 # packets to queue from TUN before handing to workers
# flush_interval: 50us # flush partially filled batches after this duration
# max_outstanding: 1028 # batches buffered per routine on each channel
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/wmem_default)
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide

View File

@@ -22,7 +22,14 @@ import (
"github.com/slackhq/nebula/udp"
)
const mtu = 9001
const (
mtu = 9001
inboundBatchSizeDefault = 32
outboundBatchSizeDefault = 32
batchFlushIntervalDefault = 50 * time.Microsecond
maxOutstandingBatchesDefault = 1028
)
type InterfaceConfig struct {
HostMap *HostMap
@@ -48,9 +55,17 @@ type InterfaceConfig struct {
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
BatchConfig BatchConfig
l *logrus.Logger
}
type BatchConfig struct {
InboundBatchSize int
OutboundBatchSize int
FlushInterval time.Duration
MaxOutstandingPerChan int
}
type Interface struct {
hostMap *HostMap
outside udp.Conn
@@ -96,9 +111,87 @@ type Interface struct {
l *logrus.Logger
pktPool *packet.Pool
inbound chan *packet.Packet
outbound chan *packet.Packet
inPool sync.Pool
inbound []chan *packetBatch
outPool sync.Pool
outbound []chan *outboundBatch
packetBatchPool sync.Pool
outboundBatchPool sync.Pool
inboundBatchSize int
outboundBatchSize int
batchFlushInterval time.Duration
maxOutstandingPerChan int
}
type packetBatch struct {
packets []*packet.Packet
}
func newPacketBatch(capacity int) *packetBatch {
return &packetBatch{
packets: make([]*packet.Packet, 0, capacity),
}
}
func (b *packetBatch) add(p *packet.Packet) {
b.packets = append(b.packets, p)
}
func (b *packetBatch) reset() {
for i := range b.packets {
b.packets[i] = nil
}
b.packets = b.packets[:0]
}
func (f *Interface) getPacketBatch() *packetBatch {
if v := f.packetBatchPool.Get(); v != nil {
b := v.(*packetBatch)
b.reset()
return b
}
return newPacketBatch(f.inboundBatchSize)
}
func (f *Interface) releasePacketBatch(b *packetBatch) {
b.reset()
f.packetBatchPool.Put(b)
}
type outboundBatch struct {
payloads []*[]byte
}
func newOutboundBatch(capacity int) *outboundBatch {
return &outboundBatch{payloads: make([]*[]byte, 0, capacity)}
}
func (b *outboundBatch) add(buf *[]byte) {
b.payloads = append(b.payloads, buf)
}
func (b *outboundBatch) reset() {
for i := range b.payloads {
b.payloads[i] = nil
}
b.payloads = b.payloads[:0]
}
func (f *Interface) getOutboundBatch() *outboundBatch {
if v := f.outboundBatchPool.Get(); v != nil {
b := v.(*outboundBatch)
b.reset()
return b
}
return newOutboundBatch(f.outboundBatchSize)
}
func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
b.reset()
f.outboundBatchPool.Put(b)
}
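
A self-contained, simplified sketch of the pool-backed batch pattern these helpers implement; batch and pool are illustrative names and a plain pointer stands in for *packet.Packet:

package main

import (
	"fmt"
	"sync"
)

// batch mirrors the shape of packetBatch/outboundBatch: a reusable slice that
// is cleared before going back to the pool so it never pins stale pointers.
type batch struct {
	items []*int
}

var pool = sync.Pool{New: func() any {
	return &batch{items: make([]*int, 0, 32)}
}}

func main() {
	b := pool.Get().(*batch)
	b.items = b.items[:0] // reset before use, as getPacketBatch does

	v := 42
	b.items = append(b.items, &v)
	fmt.Println("queued", len(b.items), "item(s)")

	// Nil out references before Put, mirroring reset(), so the pooled batch
	// does not keep the pointed-to values alive.
	for i := range b.items {
		b.items[i] = nil
	}
	b.items = b.items[:0]
	pool.Put(b)
}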
type EncWriter interface {
@@ -168,6 +261,20 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
}
cs := c.pki.getCertState()
bc := c.BatchConfig
if bc.InboundBatchSize <= 0 {
bc.InboundBatchSize = inboundBatchSizeDefault
}
if bc.OutboundBatchSize <= 0 {
bc.OutboundBatchSize = outboundBatchSizeDefault
}
if bc.FlushInterval <= 0 {
bc.FlushInterval = batchFlushIntervalDefault
}
if bc.MaxOutstandingPerChan <= 0 {
bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault
}
ifce := &Interface{
pki: c.pki,
hostMap: c.HostMap,
@@ -200,14 +307,38 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
//TODO: configurable size
inbound: make(chan *packet.Packet, 2048),
outbound: make(chan *packet.Packet, 2048),
inbound: make([]chan *packetBatch, c.routines),
outbound: make([]chan *outboundBatch, c.routines),
l: c.l,
inboundBatchSize: bc.InboundBatchSize,
outboundBatchSize: bc.OutboundBatchSize,
batchFlushInterval: bc.FlushInterval,
maxOutstandingPerChan: bc.MaxOutstandingPerChan,
}
ifce.pktPool = packet.GetPool()
for i := 0; i < c.routines; i++ {
ifce.inbound[i] = make(chan *packetBatch, ifce.maxOutstandingPerChan)
ifce.outbound[i] = make(chan *outboundBatch, ifce.maxOutstandingPerChan)
}
ifce.inPool = sync.Pool{New: func() any {
return packet.New()
}}
ifce.outPool = sync.Pool{New: func() any {
t := make([]byte, mtu)
return &t
}}
ifce.packetBatchPool = sync.Pool{New: func() any {
return newPacketBatch(ifce.inboundBatchSize)
}}
ifce.outboundBatchPool = sync.Pool{New: func() any {
return newOutboundBatch(ifce.outboundBatchSize)
}}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
@@ -258,21 +389,19 @@ func (f *Interface) activate() error {
func (f *Interface) run(c context.Context) (func(), error) {
for i := 0; i < f.routines; i++ {
// read packets from udp and queue to f.inbound
// Launch n queues to read packets from udp
f.wg.Add(1)
go f.listenOut(i)
// Launch n queues to read packets from inside tun dev and queue to f.outbound
//todo this never stops
f.wg.Add(1)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.listenIn(f.readers[i], i)
// Launch n workers to process traffic from f.inbound and smash it onto the inside of the tun
f.wg.Add(1)
go f.workerIn(i, c)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerIn(i, c)
// read from f.outbound and write to UDP (outside the tun)
// Launch n queues to read packets from tun dev
f.wg.Add(1)
go f.workerOut(i, c)
}
@@ -289,7 +418,41 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
err := li.ListenOut(f.pktPool.Get, f.inbound)
batch := f.getPacketBatch()
lastFlush := time.Now()
flush := func(force bool) {
if len(batch.packets) == 0 {
if force {
f.releasePacketBatch(batch)
}
return
}
f.inbound[i] <- batch
batch = f.getPacketBatch()
lastFlush = time.Now()
}
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
p := f.inPool.Get().(*packet.Packet)
p.Payload = p.Payload[:mtu]
copy(p.Payload, payload)
p.Payload = p.Payload[:len(payload)]
p.Addr = fromUdpAddr
batch.add(p)
if len(batch.packets) >= f.inboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
flush(false)
}
})
if len(batch.packets) > 0 {
f.inbound[i] <- batch
} else {
f.releasePacketBatch(batch)
}
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading inbound packet, closing")
//TODO: Trigger Control to close
@@ -302,9 +465,26 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
batch := f.getOutboundBatch()
lastFlush := time.Now()
flush := func(force bool) {
if len(batch.payloads) == 0 {
if force {
f.releaseOutboundBatch(batch)
}
return
}
f.outbound[i] <- batch
batch = f.getOutboundBatch()
lastFlush = time.Now()
}
for {
p := f.pktPool.Get()
n, err := reader.Read(p.Payload)
p := f.outPool.Get().(*[]byte)
*p = (*p)[:mtu]
n, err := reader.Read(*p)
if err != nil {
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
@@ -313,14 +493,18 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
break
}
p.Payload = (p.Payload)[:n]
//TODO: nonblocking channel write
f.outbound <- p
//select {
//case f.outbound <- p:
//default:
// f.l.Error("Dropped packet from outbound channel")
//}
*p = (*p)[:n]
batch.add(p)
if len(batch.payloads) >= f.outboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
flush(false)
}
}
if len(batch.payloads) > 0 {
f.outbound[i] <- batch
} else {
f.releaseOutboundBatch(batch)
}
f.l.Debugf("overlay reader %v is done", i)
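
Both reader loops above share one flush rule: hand a batch to the worker channel once it reaches the configured size, or once the flush interval has passed since the last hand-off (the age check only runs when a new item arrives). A standalone, simplified sketch of that rule with illustrative names and ints in place of packets:

package main

import (
	"fmt"
	"time"
)

type flushBatcher struct {
	batchSize     int
	flushInterval time.Duration
	out           chan []int
	buf           []int
	lastFlush     time.Time
}

func (b *flushBatcher) add(v int) {
	b.buf = append(b.buf, v)
	// Same trigger as listenOut/listenIn: size reached, or the batch is stale.
	if len(b.buf) >= b.batchSize || time.Since(b.lastFlush) >= b.flushInterval {
		b.flush()
	}
}

func (b *flushBatcher) flush() {
	if len(b.buf) == 0 {
		return
	}
	b.out <- b.buf
	b.buf = make([]int, 0, b.batchSize)
	b.lastFlush = time.Now()
}

func main() {
	b := &flushBatcher{
		batchSize:     4,
		flushInterval: 50 * time.Microsecond,
		out:           make(chan []int, 8),
		buf:           make([]int, 0, 4),
		lastFlush:     time.Now(),
	}
	for i := 0; i < 10; i++ {
		b.add(i)
	}
	b.flush() // trailing partial batch, as the reader loops do on shutdown
	close(b.out)
	for batch := range b.out {
		fmt.Println(batch)
	}
}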
@@ -337,20 +521,13 @@ func (f *Interface) workerIn(i int, ctx context.Context) {
for {
select {
case p := <-f.inbound:
if p.SegSize > 0 && p.SegSize < len(p.Payload) {
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
end := offset + p.SegSize
if end > len(p.Payload) {
end = len(p.Payload)
}
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload[offset:end], h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
}
} else {
case batch := <-f.inbound[i]:
for _, p := range batch.packets {
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
p.Payload = p.Payload[:mtu]
f.inPool.Put(p)
}
f.pktPool.Put(p)
f.releasePacketBatch(batch)
case <-ctx.Done():
f.wg.Done()
return
@@ -366,9 +543,13 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
for {
select {
case data := <-f.outbound:
f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
f.pktPool.Put(data)
case batch := <-f.outbound[i]:
for _, data := range batch.payloads {
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
*data = (*data)[:mtu]
f.outPool.Put(data)
}
f.releaseOutboundBatch(batch)
case <-ctx.Done():
f.wg.Done()
return

View File

@@ -221,6 +221,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
batchCfg := BatchConfig{
InboundBatchSize: c.GetInt("batch.inbound_size", inboundBatchSizeDefault),
OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault),
FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault),
MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault),
}
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
@@ -242,6 +249,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout,
BatchConfig: batchCfg,
l: l,
}

View File

@@ -1,45 +1,12 @@
package packet
import (
"net/netip"
"sync"
"golang.org/x/sys/unix"
)
const Size = 0xffff
import "net/netip"
type Packet struct {
Payload []byte
Control []byte
SegSize int
Addr netip.AddrPort
}
func New() *Packet {
return &Packet{
Payload: make([]byte, Size),
Control: make([]byte, unix.CmsgSpace(2)),
}
}
type Pool struct {
pool sync.Pool
}
var bigPool = &Pool{
pool: sync.Pool{New: func() any { return New() }},
}
func GetPool() *Pool {
return bigPool
}
func (p *Pool) Get() *Packet {
return p.pool.Get().(*Packet)
}
func (p *Pool) Put(x *Packet) {
x.Payload = x.Payload[:Size]
p.pool.Put(x)
return &Packet{Payload: make([]byte, 9001)}
}

View File

@@ -9,13 +9,10 @@ import (
"math"
"net"
"net/netip"
"os"
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup"
"gvisor.dev/gvisor/pkg/buffer"
@@ -46,15 +43,7 @@ type Service struct {
}
}
func New(config *config.C) (*Service, error) {
logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return nil, err
}
func New(control *nebula.Control) (*Service, error) {
wait, err := control.Start()
if err != nil {
return nil, err

View File

@@ -4,19 +4,19 @@ import (
"net/netip"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
)
const MTU = 9001
type EncReader func(*packet.Packet)
type PacketBufferGetter func() *packet.Packet
type EncReader func(
addr netip.AddrPort,
payload []byte,
)
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
ListenOut(r EncReader) error
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error
@@ -30,8 +30,8 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) {
return
func (NoopConn) ListenOut(_ EncReader) error {
return nil
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil

View File

@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) {
func (u *StdConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
for {
@@ -174,14 +174,17 @@ func (u *StdConn) ListenOut(r EncReader) {
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
u.l.WithError(err).Error("unexpected udp socket receive error")
continue
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
return nil
}
func (u *StdConn) Rebind() error {

View File

@@ -5,9 +5,11 @@ package udp
import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"syscall"
"time"
"unsafe"
@@ -15,20 +17,43 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix"
)
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
const (
defaultGSOMaxSegments = 8
defaultGSOFlushTimeout = 150 * time.Microsecond
defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
maxGSOBatchBytes = 0xFFFF
)
var (
errGSOFallback = errors.New("udp gso fallback")
errGSODisabled = errors.New("udp gso disabled")
)
type StdConn struct {
sysFd int
isV4 bool
l *logrus.Logger
batch int
enableGRO bool
enableGSO bool
//gso gsoState
gsoMu sync.Mutex
gsoBuf []byte
gsoAddr netip.AddrPort
gsoSegSize int
gsoSegments int
gsoMaxSegments int
gsoMaxBytes int
gsoFlushTimeout time.Duration
gsoTimer *time.Timer
groBufSize int
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -73,7 +98,16 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
return &StdConn{
sysFd: fd,
isV4: ip.Is4(),
l: l,
batch: batch,
gsoMaxSegments: defaultGSOMaxSegments,
gsoMaxBytes: MTU * defaultGSOMaxSegments,
gsoFlushTimeout: defaultGSOFlushTimeout,
groBufSize: MTU,
}, err
}
func (u *StdConn) Rebind() error {
@@ -122,71 +156,129 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
}
}
func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error {
var ip netip.Addr
func (u *StdConn) ListenOut(r EncReader) error {
var (
ip netip.Addr
controls [][]byte
)
msgs, packets, names := u.PrepareRawMessages(u.batch, pg)
bufSize := u.readBufferSize()
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
read := u.ReadMulti
if u.batch == 1 {
read = u.ReadSingle
}
for {
desired := u.readBufferSize()
if len(buffers) == 0 || cap(buffers[0]) < desired {
msgs, buffers, names = u.PrepareRawMessages(u.batch, desired)
controls = nil
}
if u.enableGRO {
if controls == nil {
controls = make([][]byte, len(msgs))
for i := range controls {
controls[i] = make([]byte, unix.CmsgSpace(4))
}
}
for i := range msgs {
setRawMessageControl(&msgs[i], controls[i])
}
} else if controls != nil {
for i := range msgs {
setRawMessageControl(&msgs[i], nil)
}
controls = nil
}
n, err := read(msgs)
if err != nil {
return err
}
for i := 0; i < n; i++ {
out := packets[i]
out.Payload = out.Payload[:msgs[i].Len]
// It's ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payload := buffers[i][:msgs[i].Len]
if u.enableGRO && u.l.IsLevelEnabled(logrus.DebugLevel) {
ctrlLen := getRawMessageControlLen(&msgs[i])
if ctrlLen > 0 {
packets[i].SegSize = parseGROControl(packets[i].Control[:ctrlLen])
} else {
packets[i].SegSize = 0
msgFlags := getRawMessageFlags(&msgs[i])
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "recv",
"payload_len": len(payload),
"ctrl_len": ctrlLen,
"msg_flags": msgFlags,
}).Debug("gro batch data")
if controls != nil && ctrlLen > 0 {
maxDump := ctrlLen
if maxDump > 16 {
maxDump = 16
}
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "control-bytes",
"control_hex": fmt.Sprintf("%x", controls[i][:maxDump]),
"datalen": ctrlLen,
}).Debug("gro control dump")
}
}
pc <- out
//rotate this packet out so we don't overwrite it
packets[i] = pg()
msgs[i].Hdr.Iov.Base = &packets[i].Payload[0]
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(cap(packets[i].Control))
sawControl := false
if controls != nil {
if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segSize > 0 {
sawControl = true
if u.l.IsLevelEnabled(logrus.DebugLevel) {
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "control",
"seg_size": segSize,
"seg_count": segCount,
"payloadLen": len(payload),
}).Debug("gro control parsed")
}
segSize = normalizeGROSegSize(segSize, segCount, len(payload))
if segSize > 0 && segSize < len(payload) {
if u.emitGROSegments(r, addr, payload, segSize) {
continue
}
}
}
}
}
if u.enableGRO && len(payload) > MTU {
if !sawControl && u.l.IsLevelEnabled(logrus.DebugLevel) {
u.l.WithFields(logrus.Fields{
"tag": "gro-debug",
"stage": "fallback",
"payload_len": len(payload),
}).Debug("gro control missing; splitting payload by MTU")
}
if u.emitGROSegments(r, addr, payload, MTU) {
continue
}
}
r(addr, payload)
}
}
}
func parseGROControl(control []byte) int {
if len(control) == 0 {
return 0
func (u *StdConn) readBufferSize() int {
if u.enableGRO && u.groBufSize > MTU {
return u.groBufSize
}
cmsgs, err := unix.ParseSocketControlMessage(control)
if err != nil {
return 0
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
segSize := int(binary.LittleEndian.Uint16(c.Data[:2]))
return segSize
}
}
return 0
return MTU
}
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
@@ -237,6 +329,14 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
}
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
if u.enableGSO && ip.IsValid() {
if err := u.queueGSOPacket(b, ip); err == nil {
return nil
} else if !errors.Is(err, errGSOFallback) {
return err
}
}
if u.isV4 {
return u.writeTo4(b, ip)
}
@@ -343,11 +443,23 @@ func (u *StdConn) ReloadConfig(c *config.C) {
u.l.WithError(err).Error("Failed to set listen.so_mark")
}
}
u.configureGRO(true)
u.configureGRO(c)
u.configureGSO(c)
}
func (u *StdConn) configureGRO(enable bool) {
func (u *StdConn) configureGRO(c *config.C) {
if c == nil {
return
}
enable := c.GetBool("listen.enable_gro", false)
if enable == u.enableGRO {
if enable {
if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
u.setGROBufferSize(size)
}
}
return
}
@@ -357,7 +469,8 @@ func (u *StdConn) configureGRO(enable bool) {
return
}
u.enableGRO = true
u.l.Info("UDP GRO enabled")
u.setGROBufferSize(c.GetInt("listen.gro_read_buffer", defaultGROReadBufferSize))
u.l.WithField("buffer_size", u.groBufSize).Info("UDP GRO enabled")
return
}
@@ -365,6 +478,59 @@ func (u *StdConn) configureGRO(enable bool) {
u.l.WithError(err).Warn("Failed to disable UDP GRO")
}
u.enableGRO = false
u.groBufSize = MTU
}
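
configureGRO toggles u.enableGRO and the read-buffer size here, but the setsockopt call itself is not visible in this hunk. A minimal, assumed sketch of what enabling kernel UDP GRO on a socket looks like (standalone illustration, not this file's code; requires a kernel with UDP GRO support, roughly 5.0+):

package main

import (
	"log"

	"golang.org/x/sys/unix"
)

func main() {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
	if err != nil {
		log.Fatalf("socket: %v", err)
	}
	defer unix.Close(fd)

	// With UDP_GRO set, recvmsg/recvmmsg may deliver several coalesced
	// datagrams in one buffer plus a UDP_GRO cmsg holding the segment size,
	// which is what parseGROControl reads on the receive path.
	if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
		log.Fatalf("UDP GRO unsupported on this kernel: %v", err)
	}
	log.Println("UDP GRO enabled on socket")
}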
func (u *StdConn) configureGSO(c *config.C) {
enable := c.GetBool("listen.enable_gso", false)
if !enable {
u.disableGSO()
} else {
u.enableGSO = true
}
segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
if segments < 1 {
segments = 1
}
u.gsoMaxSegments = segments
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
if maxBytes <= 0 {
maxBytes = MTU * segments
}
if maxBytes > maxGSOBatchBytes {
u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
maxBytes = maxGSOBatchBytes
}
u.gsoMaxBytes = maxBytes
timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
if timeout < 0 {
timeout = 0
}
u.gsoFlushTimeout = timeout
}
func (u *StdConn) setGROBufferSize(size int) {
if size < MTU {
size = defaultGROReadBufferSize
}
if size > maxGSOBatchBytes {
size = maxGSOBatchBytes
}
u.groBufSize = size
}
func (u *StdConn) disableGSO() {
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
u.enableGSO = false
_ = u.flushGSOlocked()
u.gsoBuf = nil
u.gsoSegments = 0
u.gsoSegSize = 0
u.stopGSOTimerLocked()
}
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
@@ -376,7 +542,239 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
return nil
}
func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
if len(b) == 0 {
return nil
}
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
if !u.enableGSO || !addr.IsValid() || len(b) > u.gsoMaxBytes {
if err := u.flushGSOlocked(); err != nil {
return err
}
return errGSOFallback
}
if u.gsoSegments == 0 {
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
} else if addr != u.gsoAddr || len(b) != u.gsoSegSize {
if err := u.flushGSOlocked(); err != nil {
return err
}
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
}
if len(u.gsoBuf)+len(b) > u.gsoMaxBytes {
if err := u.flushGSOlocked(); err != nil {
return err
}
if cap(u.gsoBuf) < u.gsoMaxBytes {
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
}
u.gsoAddr = addr
u.gsoSegSize = len(b)
}
u.gsoBuf = append(u.gsoBuf, b...)
u.gsoSegments++
if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
return u.flushGSOlocked()
}
u.scheduleGSOFlushLocked()
return nil
}
func (u *StdConn) flushGSOlocked() error {
if u.gsoSegments == 0 {
u.stopGSOTimerLocked()
return nil
}
payload := append([]byte(nil), u.gsoBuf...)
addr := u.gsoAddr
segSize := u.gsoSegSize
u.gsoBuf = u.gsoBuf[:0]
u.gsoSegments = 0
u.gsoSegSize = 0
u.stopGSOTimerLocked()
if segSize <= 0 {
return errGSOFallback
}
err := u.sendSegmented(payload, addr, segSize)
if errors.Is(err, errGSODisabled) {
u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
u.enableGSO = false
return u.sendSegmentsIndividually(payload, addr, segSize)
}
return err
}
func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
if len(payload) == 0 {
return nil
}
control := make([]byte, unix.CmsgSpace(2))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
var sa unix.Sockaddr
if addr.Addr().Is4() {
var sa4 unix.SockaddrInet4
sa4.Port = int(addr.Port())
sa4.Addr = addr.Addr().As4()
sa = &sa4
} else {
var sa6 unix.SockaddrInet6
sa6.Port = int(addr.Port())
sa6.Addr = addr.Addr().As16()
sa = &sa6
}
if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
return errGSODisabled
}
return &net.OpError{Op: "sendmsg", Err: err}
}
return nil
}
func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
if segSize <= 0 {
return errGSOFallback
}
for offset := 0; offset < len(buf); offset += segSize {
end := offset + segSize
if end > len(buf) {
end = len(buf)
}
var err error
if u.isV4 {
err = u.writeTo4(buf[offset:end], addr)
} else {
err = u.writeTo6(buf[offset:end], addr)
}
if err != nil {
return err
}
}
return nil
}
func (u *StdConn) scheduleGSOFlushLocked() {
if u.gsoTimer == nil {
u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
return
}
u.gsoTimer.Reset(u.gsoFlushTimeout)
}
func (u *StdConn) stopGSOTimerLocked() {
if u.gsoTimer != nil {
u.gsoTimer.Stop()
u.gsoTimer = nil
}
}
func (u *StdConn) gsoFlushTimer() {
u.gsoMu.Lock()
defer u.gsoMu.Unlock()
_ = u.flushGSOlocked()
}
func parseGROControl(control []byte) (int, int) {
if len(control) == 0 {
return 0, 0
}
cmsgs, err := unix.ParseSocketControlMessage(control)
if err != nil {
return 0, 0
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
segCount := 0
if len(c.Data) >= 4 {
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
}
return segSize, segCount
}
}
return 0, 0
}
func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
if segSize <= 0 {
return false
}
for offset := 0; offset < len(payload); offset += segSize {
end := offset + segSize
if end > len(payload) {
end = len(payload)
}
segment := make([]byte, end-offset)
copy(segment, payload[offset:end])
r(addr, segment)
}
return true
}
func normalizeGROSegSize(segSize, segCount, total int) int {
if segSize <= 0 || total <= 0 {
return segSize
}
if segSize > total && segCount > 0 {
segSize = total / segCount
if segSize == 0 {
segSize = total
}
}
if segCount <= 1 && segSize > 0 && total > segSize {
calculated := total / segSize
if calculated <= 1 {
calculated = (total + segSize - 1) / segSize
}
if calculated > 1 {
segCount = calculated
}
}
if segSize > MTU {
return MTU
}
return segSize
}
func (u *StdConn) Close() error {
u.disableGSO()
return syscall.Close(u.sysFd)
}

View File

@@ -30,13 +30,16 @@ type rawMessage struct {
Len uint32
}
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
if bufSize <= 0 {
bufSize = MTU
}
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
buffers[i] = make([]byte, MTU)
buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
@@ -52,3 +55,25 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names
}
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint32(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func getRawMessageFlags(msg *rawMessage) int {
return int(msg.Hdr.Flags)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint32(l)
}

View File

@@ -7,7 +7,6 @@
package udp
import (
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix"
)
@@ -34,6 +33,32 @@ type rawMessage struct {
Pad0 [4]byte
}
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
if bufSize <= 0 {
bufSize = MTU
}
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
}
return msgs, buffers, names
}
func setRawMessageControl(msg *rawMessage, buf []byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
@@ -48,40 +73,10 @@ func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func getRawMessageFlags(msg *rawMessage) int {
return int(msg.Hdr.Flags)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) {
msgs := make([]rawMessage, n)
names := make([][]byte, n)
packets := make([]*packet.Packet, n)
for i := range packets {
packets[i] = pg()
}
//todo?
for i := range msgs {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = 0
}
}
return msgs, packets, names
}