From 2ab75709ad84e02a1010b30e5181f249cd13867c Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 4 Nov 2025 15:40:33 -0600 Subject: [PATCH] hmm --- connection_state.go | 2 +- interface.go | 65 ++++++++++++++------------------------------- packet/packet.go | 28 +++++++++++++++++-- udp/conn.go | 10 +++---- udp/udp_linux.go | 15 ++++++++--- udp/udp_linux_64.go | 14 ++++++---- 6 files changed, 73 insertions(+), 61 deletions(-) diff --git a/connection_state.go b/connection_state.go index b5d6d0f..485f6fd 100644 --- a/connection_state.go +++ b/connection_state.go @@ -15,7 +15,7 @@ import ( // TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets. // 4092 should be sufficient for 5Gbps -const ReplayWindow = 1024 +const ReplayWindow = 8192 type ConnectionState struct { eKey *NebulaCipherState diff --git a/interface.go b/interface.go index 06ece3e..424a66f 100644 --- a/interface.go +++ b/interface.go @@ -96,11 +96,9 @@ type Interface struct { l *logrus.Logger - inPool sync.Pool - inbound chan *packet.Packet - - outPool sync.Pool - outbound chan *[]byte + pktPool *packet.Pool + inbound chan *packet.Packet + outbound chan *packet.Packet } type EncWriter interface { @@ -203,20 +201,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { }, //TODO: configurable size - inbound: make(chan *packet.Packet, 1028), - outbound: make(chan *[]byte, 1028), + inbound: make(chan *packet.Packet, 2048), + outbound: make(chan *packet.Packet, 2048), l: c.l, } - ifce.inPool = sync.Pool{New: func() any { - return packet.New() - }} - - ifce.outPool = sync.Pool{New: func() any { - t := make([]byte, mtu) - return &t - }} + ifce.pktPool = packet.NewPool() ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) @@ -267,19 +258,21 @@ func (f *Interface) activate() error { func (f *Interface) run(c context.Context) (func(), error) { for i := 0; i < f.routines; i++ { - // Launch n queues to read packets from udp + // read packets from udp and queue to f.inbound f.wg.Add(1) go f.listenOut(i) - // Launch n queues to read packets from tun dev - f.wg.Add(1) + // Launch n queues to read packets from inside tun dev and queue to f.outbound + //todo this never stops f.wg.Add(1) go f.listenIn(f.readers[i], i) - // Launch n queues to read packets from tun dev + // Launch n workers to process traffic from f.inbound and smash it onto the inside of the tun + f.wg.Add(1) + go f.workerIn(i, c) f.wg.Add(1) go f.workerIn(i, c) - // Launch n queues to read packets from tun dev + // read from f.outbound and write to UDP (outside the tun) f.wg.Add(1) go f.workerOut(i, c) } @@ -296,22 +289,7 @@ func (f *Interface) listenOut(i int) { li = f.outside } - err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - p := f.inPool.Get().(*packet.Packet) - //TODO: have the listener store this in the msgs array after a read instead of doing a copy - - p.Payload = p.Payload[:mtu] - copy(p.Payload, payload) - p.Payload = p.Payload[:len(payload)] - p.Addr = fromUdpAddr - f.inbound <- p - //select { - //case f.inbound <- p: - //default: - // f.l.Error("Dropped packet from inbound channel") - //} - }) - + err := li.ListenOut(f.pktPool.Get, f.inbound) if err != nil && !f.closed.Load() { f.l.WithError(err).Error("Error while reading packet inbound packet, closing") //TODO: Trigger Control to close @@ -325,9 +303,8 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() for { - p := f.outPool.Get().(*[]byte) - *p = (*p)[:mtu] - n, err := reader.Read(*p) + p := f.pktPool.Get() + n, err := reader.Read(p.Payload) if err != nil { if !f.closed.Load() { f.l.WithError(err).Error("Error while reading outbound packet, closing") @@ -336,7 +313,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { break } - *p = (*p)[:n] + p.Payload = (p.Payload)[:n] //TODO: nonblocking channel write f.outbound <- p //select { @@ -362,8 +339,7 @@ func (f *Interface) workerIn(i int, ctx context.Context) { select { case p := <-f.inbound: f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) - p.Payload = p.Payload[:mtu] - f.inPool.Put(p) + f.pktPool.Put(p) case <-ctx.Done(): f.wg.Done() return @@ -380,9 +356,8 @@ func (f *Interface) workerOut(i int, ctx context.Context) { for { select { case data := <-f.outbound: - f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) - *data = (*data)[:mtu] - f.outPool.Put(data) + f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) + f.pktPool.Put(data) case <-ctx.Done(): f.wg.Done() return diff --git a/packet/packet.go b/packet/packet.go index 83dd9dd..d6dd01c 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -1,6 +1,11 @@ package packet -import "net/netip" +import ( + "net/netip" + "sync" +) + +const Size = 9001 type Packet struct { Payload []byte @@ -8,5 +13,24 @@ type Packet struct { } func New() *Packet { - return &Packet{Payload: make([]byte, 9001)} + return &Packet{Payload: make([]byte, Size)} +} + +type Pool struct { + pool sync.Pool +} + +func NewPool() *Pool { + return &Pool{ + pool: sync.Pool{New: func() any { return New() }}, + } +} + +func (p *Pool) Get() *Packet { + return p.pool.Get().(*Packet) +} + +func (p *Pool) Put(x *Packet) { + x.Payload = x.Payload[:Size] + p.pool.Put(x) } diff --git a/udp/conn.go b/udp/conn.go index 27fcd22..6d0b79e 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -4,19 +4,19 @@ import ( "net/netip" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" ) const MTU = 9001 -type EncReader func( - addr netip.AddrPort, - payload []byte, -) +type EncReader func(*packet.Packet) + +type PacketBufferGetter func() *packet.Packet type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) error + ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e3df48f..34aaa3f 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -15,6 +15,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" "golang.org/x/sys/unix" ) @@ -118,10 +119,10 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error { var ip netip.Addr - msgs, buffers, names := u.PrepareRawMessages(u.batch) + msgs, packets, names := u.PrepareRawMessages(u.batch, pg) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle @@ -134,13 +135,21 @@ func (u *StdConn) ListenOut(r EncReader) error { } for i := 0; i < n; i++ { + out := packets[i] + out.Payload = out.Payload[:msgs[i].Len] + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + pc <- out + + //rotate this packet out so we don't overwrite it + packets[i] = pg() + msgs[i].Hdr.Iov.Base = &packets[i].Payload[0] } } } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a97..0550db7 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -7,6 +7,7 @@ package udp import ( + "github.com/slackhq/nebula/packet" "golang.org/x/sys/unix" ) @@ -33,17 +34,20 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) { msgs := make([]rawMessage, n) - buffers := make([][]byte, n) names := make([][]byte, n) + packets := make([]*packet.Packet, n) + for i := range packets { + packets[i] = pg() + } + for i := range msgs { - buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, + {Base: &packets[i].Payload[0], Len: uint64(packet.Size)}, } msgs[i].Hdr.Iov = &vs[0] @@ -53,5 +57,5 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs[i].Hdr.Namelen = uint32(len(names[i])) } - return msgs, buffers, names + return msgs, packets, names }