add batching of packets

This commit is contained in:
Ryan
2025-11-06 09:42:13 -05:00
parent a0f8cb2098
commit ad37749c5e
3 changed files with 167 additions and 42 deletions

View File

@@ -22,7 +22,14 @@ import (
"github.com/slackhq/nebula/udp"
)
const mtu = 9001
const (
mtu = 9001
inboundBatchSize = 32
outboundBatchSize = 32
batchFlushInterval = 50 * time.Microsecond
maxOutstandingBatches = 1028
)
type InterfaceConfig struct {
HostMap *HostMap
@@ -97,10 +104,81 @@ type Interface struct {
l *logrus.Logger
inPool sync.Pool
inbound chan *packet.Packet
inbound []chan *packetBatch
outPool sync.Pool
outbound chan *[]byte
outbound []chan *outboundBatch
packetBatchPool sync.Pool
outboundBatchPool sync.Pool
}
type packetBatch struct {
packets []*packet.Packet
}
func newPacketBatch() *packetBatch {
return &packetBatch{
packets: make([]*packet.Packet, 0, inboundBatchSize),
}
}
func (b *packetBatch) add(p *packet.Packet) {
b.packets = append(b.packets, p)
}
func (b *packetBatch) reset() {
for i := range b.packets {
b.packets[i] = nil
}
b.packets = b.packets[:0]
}
func (f *Interface) getPacketBatch() *packetBatch {
if v := f.packetBatchPool.Get(); v != nil {
b := v.(*packetBatch)
b.reset()
return b
}
return newPacketBatch()
}
func (f *Interface) releasePacketBatch(b *packetBatch) {
b.reset()
f.packetBatchPool.Put(b)
}
type outboundBatch struct {
payloads []*[]byte
}
func newOutboundBatch() *outboundBatch {
return &outboundBatch{payloads: make([]*[]byte, 0, outboundBatchSize)}
}
func (b *outboundBatch) add(buf *[]byte) {
b.payloads = append(b.payloads, buf)
}
func (b *outboundBatch) reset() {
for i := range b.payloads {
b.payloads[i] = nil
}
b.payloads = b.payloads[:0]
}
func (f *Interface) getOutboundBatch() *outboundBatch {
if v := f.outboundBatchPool.Get(); v != nil {
b := v.(*outboundBatch)
b.reset()
return b
}
return newOutboundBatch()
}
func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
b.reset()
f.outboundBatchPool.Put(b)
}
type EncWriter interface {
@@ -203,12 +281,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
},
//TODO: configurable size
inbound: make(chan *packet.Packet, 1028),
outbound: make(chan *[]byte, 1028),
inbound: make([]chan *packetBatch, c.routines),
outbound: make([]chan *outboundBatch, c.routines),
l: c.l,
}
for i := 0; i < c.routines; i++ {
ifce.inbound[i] = make(chan *packetBatch, maxOutstandingBatches)
ifce.outbound[i] = make(chan *outboundBatch, maxOutstandingBatches)
}
ifce.inPool = sync.Pool{New: func() any {
return packet.New()
}}
@@ -218,6 +301,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return &t
}}
ifce.packetBatchPool = sync.Pool{New: func() any {
return newPacketBatch()
}}
ifce.outboundBatchPool = sync.Pool{New: func() any {
return newOutboundBatch()
}}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -296,22 +387,41 @@ func (f *Interface) listenOut(i int) {
li = f.outside
}
batch := f.getPacketBatch()
lastFlush := time.Now()
flush := func(force bool) {
if len(batch.packets) == 0 {
if force {
f.releasePacketBatch(batch)
}
return
}
f.inbound[i] <- batch
batch = f.getPacketBatch()
lastFlush = time.Now()
}
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
p := f.inPool.Get().(*packet.Packet)
//TODO: have the listener store this in the msgs array after a read instead of doing a copy
p.Payload = p.Payload[:mtu]
copy(p.Payload, payload)
p.Payload = p.Payload[:len(payload)]
p.Addr = fromUdpAddr
f.inbound <- p
//select {
//case f.inbound <- p:
//default:
// f.l.Error("Dropped packet from inbound channel")
//}
batch.add(p)
if len(batch.packets) >= inboundBatchSize || time.Since(lastFlush) >= batchFlushInterval {
flush(false)
}
})
if len(batch.packets) > 0 {
f.inbound[i] <- batch
} else {
f.releasePacketBatch(batch)
}
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
@@ -324,6 +434,22 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
batch := f.getOutboundBatch()
lastFlush := time.Now()
flush := func(force bool) {
if len(batch.payloads) == 0 {
if force {
f.releaseOutboundBatch(batch)
}
return
}
f.outbound[i] <- batch
batch = f.getOutboundBatch()
lastFlush = time.Now()
}
for {
p := f.outPool.Get().(*[]byte)
*p = (*p)[:mtu]
@@ -337,13 +463,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
*p = (*p)[:n]
//TODO: nonblocking channel write
f.outbound <- p
//select {
//case f.outbound <- p:
//default:
// f.l.Error("Dropped packet from outbound channel")
//}
batch.add(p)
if len(batch.payloads) >= outboundBatchSize || time.Since(lastFlush) >= batchFlushInterval {
flush(false)
}
}
if len(batch.payloads) > 0 {
f.outbound[i] <- batch
} else {
f.releaseOutboundBatch(batch)
}
f.l.Debugf("overlay reader %v is done", i)
@@ -360,10 +490,13 @@ func (f *Interface) workerIn(i int, ctx context.Context) {
for {
select {
case p := <-f.inbound:
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
p.Payload = p.Payload[:mtu]
f.inPool.Put(p)
case batch := <-f.inbound[i]:
for _, p := range batch.packets {
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
p.Payload = p.Payload[:mtu]
f.inPool.Put(p)
}
f.releasePacketBatch(batch)
case <-ctx.Done():
f.wg.Done()
return
@@ -379,10 +512,13 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
for {
select {
case data := <-f.outbound:
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
*data = (*data)[:mtu]
f.outPool.Put(data)
case batch := <-f.outbound[i]:
for _, data := range batch.payloads {
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
*data = (*data)[:mtu]
f.outPool.Put(data)
}
f.releaseOutboundBatch(batch)
case <-ctx.Done():
f.wg.Done()
return

View File

@@ -9,13 +9,10 @@ import (
"math"
"net"
"net/netip"
"os"
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup"
"gvisor.dev/gvisor/pkg/buffer"
@@ -46,15 +43,7 @@ type Service struct {
}
}
func New(config *config.C) (*Service, error) {
logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return nil, err
}
func New(control *nebula.Control) (*Service, error) {
wait, err := control.Start()
if err != nil {
return nil, err

View File

@@ -30,8 +30,8 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) {
return
func (NoopConn) ListenOut(_ EncReader) error {
return nil
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil