From ad37749c5e7575453cf4e21fbad1e307605f9061 Mon Sep 17 00:00:00 2001 From: Ryan Date: Thu, 6 Nov 2025 09:42:13 -0500 Subject: [PATCH] add batching of packets --- interface.go | 192 ++++++++++++++++++++++++++++++++++++++------- service/service.go | 13 +-- udp/conn.go | 4 +- 3 files changed, 167 insertions(+), 42 deletions(-) diff --git a/interface.go b/interface.go index 06ece3e..3df31d3 100644 --- a/interface.go +++ b/interface.go @@ -22,7 +22,14 @@ import ( "github.com/slackhq/nebula/udp" ) -const mtu = 9001 +const ( + mtu = 9001 + + inboundBatchSize = 32 + outboundBatchSize = 32 + batchFlushInterval = 50 * time.Microsecond + maxOutstandingBatches = 1028 +) type InterfaceConfig struct { HostMap *HostMap @@ -97,10 +104,81 @@ type Interface struct { l *logrus.Logger inPool sync.Pool - inbound chan *packet.Packet + inbound []chan *packetBatch outPool sync.Pool - outbound chan *[]byte + outbound []chan *outboundBatch + + packetBatchPool sync.Pool + outboundBatchPool sync.Pool +} + +type packetBatch struct { + packets []*packet.Packet +} + +func newPacketBatch() *packetBatch { + return &packetBatch{ + packets: make([]*packet.Packet, 0, inboundBatchSize), + } +} + +func (b *packetBatch) add(p *packet.Packet) { + b.packets = append(b.packets, p) +} + +func (b *packetBatch) reset() { + for i := range b.packets { + b.packets[i] = nil + } + b.packets = b.packets[:0] +} + +func (f *Interface) getPacketBatch() *packetBatch { + if v := f.packetBatchPool.Get(); v != nil { + b := v.(*packetBatch) + b.reset() + return b + } + return newPacketBatch() +} + +func (f *Interface) releasePacketBatch(b *packetBatch) { + b.reset() + f.packetBatchPool.Put(b) +} + +type outboundBatch struct { + payloads []*[]byte +} + +func newOutboundBatch() *outboundBatch { + return &outboundBatch{payloads: make([]*[]byte, 0, outboundBatchSize)} +} + +func (b *outboundBatch) add(buf *[]byte) { + b.payloads = append(b.payloads, buf) +} + +func (b *outboundBatch) reset() { + for i := range b.payloads { + b.payloads[i] = nil + } + b.payloads = b.payloads[:0] +} + +func (f *Interface) getOutboundBatch() *outboundBatch { + if v := f.outboundBatchPool.Get(); v != nil { + b := v.(*outboundBatch) + b.reset() + return b + } + return newOutboundBatch() +} + +func (f *Interface) releaseOutboundBatch(b *outboundBatch) { + b.reset() + f.outboundBatchPool.Put(b) } type EncWriter interface { @@ -203,12 +281,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { }, //TODO: configurable size - inbound: make(chan *packet.Packet, 1028), - outbound: make(chan *[]byte, 1028), + inbound: make([]chan *packetBatch, c.routines), + outbound: make([]chan *outboundBatch, c.routines), l: c.l, } + for i := 0; i < c.routines; i++ { + ifce.inbound[i] = make(chan *packetBatch, maxOutstandingBatches) + ifce.outbound[i] = make(chan *outboundBatch, maxOutstandingBatches) + } + ifce.inPool = sync.Pool{New: func() any { return packet.New() }} @@ -218,6 +301,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return &t }} + ifce.packetBatchPool = sync.Pool{New: func() any { + return newPacketBatch() + }} + + ifce.outboundBatchPool = sync.Pool{New: func() any { + return newOutboundBatch() + }} + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) @@ -296,22 +387,41 @@ func (f *Interface) listenOut(i int) { li = f.outside } + batch := f.getPacketBatch() + lastFlush := time.Now() + + flush := func(force bool) { + if len(batch.packets) == 0 { + if force { + f.releasePacketBatch(batch) + } + return + } + + f.inbound[i] <- batch + batch = f.getPacketBatch() + lastFlush = time.Now() + } + err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { p := f.inPool.Get().(*packet.Packet) - //TODO: have the listener store this in the msgs array after a read instead of doing a copy - p.Payload = p.Payload[:mtu] copy(p.Payload, payload) p.Payload = p.Payload[:len(payload)] p.Addr = fromUdpAddr - f.inbound <- p - //select { - //case f.inbound <- p: - //default: - // f.l.Error("Dropped packet from inbound channel") - //} + batch.add(p) + + if len(batch.packets) >= inboundBatchSize || time.Since(lastFlush) >= batchFlushInterval { + flush(false) + } }) + if len(batch.packets) > 0 { + f.inbound[i] <- batch + } else { + f.releasePacketBatch(batch) + } + if err != nil && !f.closed.Load() { f.l.WithError(err).Error("Error while reading packet inbound packet, closing") //TODO: Trigger Control to close @@ -324,6 +434,22 @@ func (f *Interface) listenOut(i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() + batch := f.getOutboundBatch() + lastFlush := time.Now() + + flush := func(force bool) { + if len(batch.payloads) == 0 { + if force { + f.releaseOutboundBatch(batch) + } + return + } + + f.outbound[i] <- batch + batch = f.getOutboundBatch() + lastFlush = time.Now() + } + for { p := f.outPool.Get().(*[]byte) *p = (*p)[:mtu] @@ -337,13 +463,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } *p = (*p)[:n] - //TODO: nonblocking channel write - f.outbound <- p - //select { - //case f.outbound <- p: - //default: - // f.l.Error("Dropped packet from outbound channel") - //} + batch.add(p) + + if len(batch.payloads) >= outboundBatchSize || time.Since(lastFlush) >= batchFlushInterval { + flush(false) + } + } + + if len(batch.payloads) > 0 { + f.outbound[i] <- batch + } else { + f.releaseOutboundBatch(batch) } f.l.Debugf("overlay reader %v is done", i) @@ -360,10 +490,13 @@ func (f *Interface) workerIn(i int, ctx context.Context) { for { select { - case p := <-f.inbound: - f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) - p.Payload = p.Payload[:mtu] - f.inPool.Put(p) + case batch := <-f.inbound[i]: + for _, p := range batch.packets { + f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) + p.Payload = p.Payload[:mtu] + f.inPool.Put(p) + } + f.releasePacketBatch(batch) case <-ctx.Done(): f.wg.Done() return @@ -379,10 +512,13 @@ func (f *Interface) workerOut(i int, ctx context.Context) { for { select { - case data := <-f.outbound: - f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) - *data = (*data)[:mtu] - f.outPool.Put(data) + case batch := <-f.outbound[i]: + for _, data := range batch.payloads { + f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) + *data = (*data)[:mtu] + f.outPool.Put(data) + } + f.releaseOutboundBatch(batch) case <-ctx.Done(): f.wg.Done() return diff --git a/service/service.go b/service/service.go index 16c244b..c86d08c 100644 --- a/service/service.go +++ b/service/service.go @@ -9,13 +9,10 @@ import ( "math" "net" "net/netip" - "os" "strings" "sync" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula" - "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "golang.org/x/sync/errgroup" "gvisor.dev/gvisor/pkg/buffer" @@ -46,15 +43,7 @@ type Service struct { } } -func New(config *config.C) (*Service, error) { - logger := logrus.New() - logger.Out = os.Stdout - - control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) - if err != nil { - return nil, err - } - +func New(control *nebula.Control) (*Service, error) { wait, err := control.Start() if err != nil { return nil, err diff --git a/udp/conn.go b/udp/conn.go index 27fcd22..340d30c 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -30,8 +30,8 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return +func (NoopConn) ListenOut(_ EncReader) error { + return nil } func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil