mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-14 08:44:24 +01:00
batch tun reads
This commit is contained in:
139
inside.go
139
inside.go
@@ -9,8 +9,75 @@ import (
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// consumeInsidePacketBatched is a variant of consumeInsidePacket that queues
|
||||
// outgoing packets into pendingPackets instead of sending them immediately.
|
||||
// The caller is responsible for flushing pendingPackets with WriteBatch.
|
||||
func (f *Interface) consumeInsidePacketBatched(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache, pendingPackets *[]udp.BatchPacket) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore local broadcast packets
|
||||
if f.dropLocalBroadcast {
|
||||
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||
if immediatelyForwardToSelf {
|
||||
_, err := f.readers[q].Write(packet)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to forward to tun")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore multicast packets
|
||||
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||
WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !ready {
|
||||
return
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetricsBatched(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q, pendingPackets)
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).
|
||||
WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping outbound packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
@@ -409,3 +476,75 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendNoMetricsBatched is like sendNoMetrics but queues the packet for batched sending
|
||||
// instead of sending immediately. The caller must flush pendingPackets with WriteBatch.
|
||||
func (f *Interface) sendNoMetricsBatched(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, pendingPackets *[]udp.BatchPacket) {
|
||||
if ci.eKey == nil {
|
||||
return
|
||||
}
|
||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||
fullOut := out
|
||||
|
||||
if useRelay {
|
||||
if len(out) < header.Len {
|
||||
out = out[:header.Len]
|
||||
}
|
||||
out = out[header.Len:]
|
||||
}
|
||||
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Lock()
|
||||
}
|
||||
c := ci.messageCounter.Add(1)
|
||||
|
||||
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
|
||||
f.connectionManager.Out(hostinfo)
|
||||
|
||||
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
return
|
||||
}
|
||||
|
||||
// Queue the packet for batched sending
|
||||
var addr netip.AddrPort
|
||||
if remote.IsValid() {
|
||||
addr = remote
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
addr = hostinfo.remote
|
||||
} else {
|
||||
// Relay path - send immediately, not batched
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetricsBatched failed to find HostInfo")
|
||||
continue
|
||||
}
|
||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the payload since the buffer will be reused
|
||||
payload := make([]byte, len(out))
|
||||
copy(payload, out)
|
||||
*pendingPackets = append(*pendingPackets, udp.BatchPacket{Payload: payload, Addr: addr})
|
||||
}
|
||||
|
||||
77
interface.go
77
interface.go
@@ -48,6 +48,8 @@ type InterfaceConfig struct {
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
|
||||
tunBatchSize int // batch size for TUN read/write batching, 0 to disable
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
@@ -86,8 +88,9 @@ type Interface struct {
|
||||
|
||||
conntrackCacheTimeout time.Duration
|
||||
|
||||
writers []udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
writers []udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
tunBatchSize int // batch size for TUN read/write batching
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
@@ -187,6 +190,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
relayManager: c.relayManager,
|
||||
connectionManager: c.connectionManager,
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
tunBatchSize: c.tunBatchSize,
|
||||
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
messageMetrics: c.MessageMetrics,
|
||||
@@ -244,6 +248,15 @@ func (f *Interface) activate() {
|
||||
f.readers[i] = reader
|
||||
}
|
||||
|
||||
// Enable batch reading on all readers if batch size > 1
|
||||
if f.tunBatchSize > 1 {
|
||||
for i := 0; i < f.routines; i++ {
|
||||
if err := overlay.EnableBatchReading(f.readers[i]); err != nil {
|
||||
f.l.WithError(err).WithField("routine", i).Warn("Failed to enable batch reading, falling back to single reads")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := f.inside.Activate(); err != nil {
|
||||
f.inside.Close()
|
||||
f.l.Fatal(err)
|
||||
@@ -287,13 +300,21 @@ func (f *Interface) listenOut(i int) {
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
// Check if batch reading is available and enabled
|
||||
batchReader := overlay.AsBatchReader(reader)
|
||||
if batchReader != nil && f.tunBatchSize > 1 {
|
||||
f.listenInBatched(reader, batchReader, i, conntrackCache)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback to single-packet reading
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := reader.Read(packet)
|
||||
if err != nil {
|
||||
@@ -310,6 +331,54 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) listenInBatched(reader io.ReadWriteCloser, batchReader overlay.BatchReader, i int, conntrackCache *firewall.ConntrackCacheTicker) {
|
||||
batchSize := f.tunBatchSize
|
||||
|
||||
// Pre-allocate buffers for batch reading
|
||||
packets := make([][]byte, batchSize)
|
||||
for j := range packets {
|
||||
packets[j] = make([]byte, mtu)
|
||||
}
|
||||
sizes := make([]int, batchSize)
|
||||
|
||||
// Pre-allocate buffers for packet processing
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
// Pre-allocate buffer for batched UDP writes
|
||||
pendingPackets := make([]udp.BatchPacket, 0, batchSize)
|
||||
|
||||
for {
|
||||
// Read a batch of packets from TUN
|
||||
n, err := batchReader.ReadBatch(packets, sizes)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
f.l.WithError(err).Error("Error while reading outbound packets")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process all packets in the batch
|
||||
cache := conntrackCache.Get(f.l)
|
||||
for j := 0; j < n; j++ {
|
||||
f.consumeInsidePacketBatched(packets[j][:sizes[j]], fwPacket, nb, out, i, cache, &pendingPackets)
|
||||
}
|
||||
|
||||
// Flush all pending UDP writes
|
||||
if len(pendingPackets) > 0 {
|
||||
f.writers[i].WriteBatch(pendingPackets)
|
||||
pendingPackets = pendingPackets[:0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
c.RegisterReloadCallback(f.reloadFirewall)
|
||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||
|
||||
1
main.go
1
main.go
@@ -250,6 +250,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
punchy: punchy,
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
tunBatchSize: c.GetInt("tun.batch", 64),
|
||||
}
|
||||
|
||||
var ifce *Interface
|
||||
|
||||
@@ -16,3 +16,38 @@ type Device interface {
|
||||
SupportsMultiqueue() bool
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
||||
// BatchReader is an optional interface that devices can implement
|
||||
// to support reading multiple packets in a single batch operation.
|
||||
// This can significantly reduce syscall overhead under high load.
|
||||
type BatchReader interface {
|
||||
// ReadBatch reads up to len(packets) packets into the provided buffers.
|
||||
// Each packet is read into packets[i] and its length is stored in sizes[i].
|
||||
// Returns the number of packets read, or an error.
|
||||
// A return of (0, nil) indicates no packets were available (non-blocking).
|
||||
ReadBatch(packets [][]byte, sizes []int) (int, error)
|
||||
}
|
||||
|
||||
// AsBatchReader returns a BatchReader if the reader supports batch operations,
|
||||
// otherwise returns nil.
|
||||
func AsBatchReader(r io.ReadWriteCloser) BatchReader {
|
||||
if br, ok := r.(BatchReader); ok {
|
||||
return br
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchEnabler is an optional interface for devices that need explicit
|
||||
// enabling of batch read support (e.g., setting non-blocking mode).
|
||||
type BatchEnabler interface {
|
||||
EnableBatchReading() error
|
||||
}
|
||||
|
||||
// EnableBatchReading enables batch reading on the device if supported.
|
||||
// Returns nil if the device doesn't support or need explicit enabling.
|
||||
func EnableBatchReading(d interface{}) error {
|
||||
if be, ok := d.(BatchEnabler); ok {
|
||||
return be.EnableBatchReading()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ type tun struct {
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
nonBlocking bool // true if fd is in non-blocking mode
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
@@ -239,7 +240,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR|unix.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -251,9 +252,99 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
return &tunBatchReader{fd: fd, device: t.Device}, nil
|
||||
}
|
||||
|
||||
return file, nil
|
||||
// tunBatchReader implements BatchReader for efficient batch packet reading
|
||||
type tunBatchReader struct {
|
||||
fd int
|
||||
device string
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Read(b []byte) (int, error) {
|
||||
// Use poll to wait for data, then read
|
||||
for {
|
||||
n, err := unix.Read(r.fd, b)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// Wait for data
|
||||
pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Write(b []byte) (int, error) {
|
||||
return unix.Write(r.fd, b)
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Close() error {
|
||||
return unix.Close(r.fd)
|
||||
}
|
||||
|
||||
// ReadBatch reads up to len(packets) packets from the TUN device.
|
||||
// It drains all available packets without blocking, using poll() only
|
||||
// when no packets have been read yet.
|
||||
func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||
count := 0
|
||||
maxPackets := len(packets)
|
||||
if len(sizes) < maxPackets {
|
||||
maxPackets = len(sizes)
|
||||
}
|
||||
|
||||
for count < maxPackets {
|
||||
n, err := unix.Read(r.fd, packets[count])
|
||||
if err == nil && n > 0 {
|
||||
sizes[count] = n
|
||||
count++
|
||||
continue
|
||||
}
|
||||
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// No more packets available
|
||||
if count > 0 {
|
||||
// We have some packets, return them
|
||||
return count, nil
|
||||
}
|
||||
// No packets yet, wait for at least one
|
||||
pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if count > 0 {
|
||||
// Return what we have
|
||||
return count, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
@@ -284,6 +375,111 @@ func (t *tun) Write(b []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// EnableBatchReading sets the TUN fd to non-blocking mode to enable batch reading.
|
||||
// This should be called before using ReadBatch.
|
||||
func (t *tun) EnableBatchReading() error {
|
||||
if t.nonBlocking {
|
||||
return nil
|
||||
}
|
||||
err := unix.SetNonblock(t.fd, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.nonBlocking = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read overrides the default Read to handle non-blocking mode
|
||||
func (t *tun) Read(b []byte) (int, error) {
|
||||
if !t.nonBlocking {
|
||||
// Use the embedded ReadWriteCloser for blocking reads
|
||||
return t.ReadWriteCloser.Read(b)
|
||||
}
|
||||
|
||||
// Non-blocking read with poll
|
||||
for {
|
||||
n, err := unix.Read(t.fd, b)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// Wait for data
|
||||
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// ReadBatch reads up to len(packets) packets from the TUN device.
|
||||
// EnableBatchReading must be called first.
|
||||
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||
if !t.nonBlocking {
|
||||
// Fallback to single read if non-blocking not enabled
|
||||
n, err := t.ReadWriteCloser.Read(packets[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
count := 0
|
||||
maxPackets := len(packets)
|
||||
if len(sizes) < maxPackets {
|
||||
maxPackets = len(sizes)
|
||||
}
|
||||
|
||||
for count < maxPackets {
|
||||
n, err := unix.Read(t.fd, packets[count])
|
||||
if err == nil && n > 0 {
|
||||
sizes[count] = n
|
||||
count++
|
||||
continue
|
||||
}
|
||||
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// No more packets available
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
// No packets yet, wait for at least one
|
||||
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
|
||||
Reference in New Issue
Block a user