From bc9711df68d6d0d7b343f5a613b59ba0228c9d06 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Wed, 19 Nov 2025 16:58:59 -0500 Subject: [PATCH] batch more writes --- interface.go | 64 +++++------------------- main.go | 2 +- outside.go | 105 ++++++++++++++++++++++++++++++++++++++++ overlay/device.go | 18 ++++++- overlay/tun_android.go | 25 +++++++++- overlay/tun_darwin.go | 28 ++++++++++- overlay/tun_disabled.go | 28 ++++++++++- overlay/tun_freebsd.go | 28 ++++++++++- overlay/tun_ios.go | 25 +++++++++- overlay/tun_linux.go | 2 +- overlay/tun_netbsd.go | 25 +++++++++- overlay/tun_openbsd.go | 25 +++++++++- overlay/tun_tester.go | 25 +++++++++- overlay/tun_windows.go | 29 ++++++++++- overlay/user.go | 28 ++++++++++- 15 files changed, 389 insertions(+), 68 deletions(-) diff --git a/interface.go b/interface.go index 74e2c84..905d06d 100644 --- a/interface.go +++ b/interface.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/netip" "os" "runtime" @@ -52,9 +51,10 @@ type InterfaceConfig struct { } type batchMetrics struct { - udpReadSize metrics.Histogram - tunReadSize metrics.Histogram - udpWriteSize metrics.Histogram + udpReadSize metrics.Histogram + tunReadSize metrics.Histogram + udpWriteSize metrics.Histogram + tunWriteSize metrics.Histogram } type Interface struct { @@ -93,7 +93,7 @@ type Interface struct { conntrackCacheTimeout time.Duration writers []udp.Conn - readers []io.ReadWriteCloser + readers []overlay.BatchReadWriter metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -185,7 +185,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), + readers: make([]overlay.BatchReadWriter, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, @@ -205,6 +205,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)), tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)), udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)), + tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)), }, l: c.l, @@ -238,7 +239,7 @@ func (f *Interface) activate() { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues - var reader io.ReadWriteCloser = f.inside + var reader overlay.BatchReadWriter = f.inside for i := 0; i < f.routines; i++ { if i > 0 { reader, err = f.inside.NewMultiQueueReader() @@ -297,53 +298,10 @@ func (f *Interface) listenOut(i int) { }) } -// BatchReader is an interface for devices that support reading multiple packets at once -type BatchReader interface { - BatchRead(bufs [][]byte, sizes []int) (int, error) - BatchSize() int -} - -func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { +func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) { runtime.LockOSThread() - // Check if reader supports batching - batchReader, supportsBatching := reader.(BatchReader) - - if supportsBatching { - f.listenInBatch(reader, batchReader, i) - } else { - f.listenInSingle(reader, i) - } -} - -func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { - packet := make([]byte, mtu) - // Allocate out buffer with virtio header headroom (10 bytes) to avoid copies on write - outBuf := make([]byte, virtioNetHdrLen+mtu) - out := outBuf[virtioNetHdrLen:] // Use slice starting after headroom - fwPacket := &firewall.Packet{} - nb := make([]byte, 12) - - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - - for { - n, err := reader.Read(packet) - if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { - return - } - - f.l.WithError(err).Error("Error while reading outbound packet") - // This only seems to happen when something fatal happens to the fd, so exit. - os.Exit(2) - } - - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) - } -} - -func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) { - batchSize := batchReader.BatchSize() + batchSize := reader.BatchSize() // Allocate buffers for batch reading bufs := make([][]byte, batchSize) @@ -370,7 +328,7 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { - n, err := batchReader.BatchRead(bufs, sizes) + n, err := reader.BatchRead(bufs, sizes) if err != nil { if errors.Is(err, os.ErrClosed) && f.closed.Load() { return diff --git a/main.go b/main.go index 98c849a..49ca9f7 100644 --- a/main.go +++ b/main.go @@ -165,7 +165,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for i := 0; i < routines; i++ { l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) - udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) + udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } diff --git a/outside.go b/outside.go index eae15f3..5c681f8 100644 --- a/outside.go +++ b/outside.go @@ -549,3 +549,108 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { // We also delete it from pending hostmap to allow for fast reconnect. f.handshakeManager.DeleteHostInfo(hostinfo) } + +// readOutsidePacketsBatch processes multiple packets received from UDP in a batch +// and writes all successfully decrypted packets to TUN in a single operation +func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) { + // Pre-allocate slice for accumulating successful decryptions + tunPackets := make([][]byte, 0, count) + + for i := 0; i < count; i++ { + payload := payloads[i] + addr := addrs[i] + out := outs[i] + + // Parse header + err := h.Parse(payload) + if err != nil { + if len(payload) > 1 { + f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err) + } + continue + } + + if addr.IsValid() { + if f.myVpnNetworksTable.Contains(addr.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") + } + continue + } + } + + var hostinfo *HostInfo + if h.Type == header.Message && h.Subtype == header.MessageRelay { + hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) + } else { + hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) + } + + var ci *ConnectionState + if hostinfo != nil { + ci = hostinfo.ConnectionState + } + + switch h.Type { + case header.Message: + if !f.handleEncrypted(ci, addr, h) { + continue + } + + switch h.Subtype { + case header.MessageNone: + // Decrypt packet + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") + continue + } + + packetData := out[virtioNetHdrLen:] + + err = newPacket(packetData, true, fwPacket) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet") + continue + } + + if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet") + continue + } + + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + if dropReason != nil { + f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet") + } + continue + } + + f.connectionManager.In(hostinfo) + // Add to batch for TUN write + tunPackets = append(tunPackets, out) + + case header.MessageRelay: + // Skip relay packets in batch mode for now (less common path) + f.readOutsidePackets(addr, nil, out[:virtioNetHdrLen], payload, h, fwPacket, lhf, nb, q, localCache) + + default: + hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype) + } + + default: + // Handle non-Message types using single-packet path + f.readOutsidePackets(addr, nil, out[:virtioNetHdrLen], payload, h, fwPacket, lhf, nb, q, localCache) + } + } + + if len(tunPackets) > 0 { + n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen) + if err != nil { + f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun") + } + f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets))) + } +} diff --git a/overlay/device.go b/overlay/device.go index 07146ab..8683ead 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -7,11 +7,25 @@ import ( "github.com/slackhq/nebula/routing" ) -type Device interface { +// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations +type BatchReadWriter interface { io.ReadWriteCloser + + // BatchRead reads multiple packets at once + BatchRead(bufs [][]byte, sizes []int) (int, error) + + // WriteBatch writes multiple packets at once + WriteBatch(bufs [][]byte, offset int) (int, error) + + // BatchSize returns the optimal batch size for this device + BatchSize() int +} + +type Device interface { + BatchReadWriter Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways - NewMultiQueueReader() (io.ReadWriteCloser, error) + NewMultiQueueReader() (BatchReadWriter, error) } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index df1ed8d..3e88bb3 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -95,6 +95,29 @@ func (t *tun) Name() string { return "android" } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for android") } + +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 34c2a71..d9c8830 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -549,6 +549,32 @@ func (t *tun) Name() string { return t.Device } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +// BatchRead reads a single packet (batch size 1 for non-Linux platforms) +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for non-Linux platforms) +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for non-Linux platforms (no batching) +func (t *tun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 131879d..ee63b1e 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -105,10 +105,36 @@ func (t *disabledTun) Write(b []byte) (int, error) { return len(b), nil } -func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) { return t, nil } +// BatchRead reads a single packet (batch size 1 for disabled tun) +func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for disabled tun) +func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for disabled tun (no batching) +func (t *disabledTun) BatchSize() int { + return 1 +} + func (t *disabledTun) Close() error { if t.read != nil { close(t.read) diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 0dd8228..8d037d4 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -450,10 +450,36 @@ func (t *tun) Name() string { return t.Device } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } +// BatchRead reads a single packet (batch size 1 for FreeBSD) +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for FreeBSD) +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for FreeBSD (no batching) +func (t *tun) BatchSize() int { + return 1 +} + func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index e51e112..369184e 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -151,6 +151,29 @@ func (t *tun) Name() string { return "iOS" } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") } + +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 3c98d72..f7d9f6e 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -312,7 +312,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 49ac19f..4c298ae 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -390,10 +390,33 @@ func (t *tun) Name() string { return t.Device } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") } +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} + func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 52d5297..4e0cc2d 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -310,10 +310,33 @@ func (t *tun) Name() string { return t.Device } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") } +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *tun) BatchSize() int { + return 1 +} + func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index b6712fb..bd0da4c 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -132,6 +132,29 @@ func (t *TestTun) Read(b []byte) (int, error) { return len(p), nil } -func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } + +func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (t *TestTun) BatchSize() int { + return 1 +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 7aac128..3ab882f 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -6,7 +6,6 @@ package overlay import ( "crypto" "fmt" - "io" "net/netip" "os" "path/filepath" @@ -234,10 +233,36 @@ func (t *winTun) Write(b []byte) (int, error) { return t.tun.Write(b, 0) } -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") } +// BatchRead reads a single packet (batch size 1 for Windows) +func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := t.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for Windows) +func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := t.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for Windows (no batching) +func (t *winTun) BatchSize() int { + return 1 +} + func (t *winTun) Close() error { // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, // so to be certain, just remove everything before destroying. diff --git a/overlay/user.go b/overlay/user.go index 8a56d66..18ae0bf 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -46,10 +46,36 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { return routing.Gateways{routing.NewGateway(ip, 1)} } -func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) { return d, nil } +// BatchRead reads a single packet (batch size 1 for UserDevice) +func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) { + n, err := d.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// WriteBatch writes packets individually (no batching for UserDevice) +func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + _, err := d.Write(buf[offset:]) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + +// BatchSize returns 1 for UserDevice (no batching) +func (d *UserDevice) BatchSize() int { + return 1 +} + func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { return d.inboundReader, d.outboundWriter }