mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
batched tun interface
This commit is contained in:
@@ -4,15 +4,13 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"log/slog"
|
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
@@ -382,7 +380,7 @@ func getAddrs(ns []netip.Prefix) []netip.Addr {
|
|||||||
func NewTestLogger() *slog.Logger {
|
func NewTestLogger() *slog.Logger {
|
||||||
v := os.Getenv("TEST_LOGS")
|
v := os.Getenv("TEST_LOGS")
|
||||||
if v == "" {
|
if v == "" {
|
||||||
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
return slog.New(slog.DiscardHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
level := slog.LevelInfo
|
level := slog.LevelInfo
|
||||||
|
|||||||
@@ -974,6 +974,7 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn
|
|||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
for _, cp := range hh.packetStore {
|
for _, cp := range hh.packetStore {
|
||||||
|
//todo use a sendbatcher
|
||||||
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
|
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
|
||||||
}
|
}
|
||||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||||
|
|||||||
192
inside.go
192
inside.go
@@ -2,6 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
@@ -9,10 +10,16 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
|
"github.com/slackhq/nebula/overlay/batch"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
|
// borrowed: pkt.Bytes is owned by the originating tio.Queue and is
|
||||||
|
// only valid until the next Read on that queue. If you must keep
|
||||||
|
// the packet, use pkt.Clone() to detach it
|
||||||
|
packet := pkt.Bytes
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
@@ -37,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||||
// TUN device.
|
// TUN device.
|
||||||
if immediatelyForwardToSelf {
|
if immediatelyForwardToSelf {
|
||||||
_, err := f.readers[q].Write(packet)
|
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
|
||||||
|
_, werr := f.readers[q].Write(seg)
|
||||||
|
return werr
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Error("Failed to forward to tun", "error", err)
|
f.l.Error("Failed to forward to tun", "error", err)
|
||||||
}
|
}
|
||||||
@@ -53,11 +63,23 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
// borrowed: SegmentSuperpacket builds each segment in the kernel-supplied pkt
|
||||||
|
// bytes underneath. cachePacket explicitly copies its argument (handshake_manager.go cachePacket),
|
||||||
|
// so retaining segments past the loop is safe.
|
||||||
|
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
|
||||||
|
hh.cachePacket(f.l, header.Message, 0, seg, f.sendMessageNow, f.cachedPacketMetrics)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil && f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
f.l.Debug("Failed to segment superpacket for handshake cache",
|
||||||
|
"error", err,
|
||||||
|
"vpnAddr", fwPacket.RemoteAddr,
|
||||||
|
)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, rejectBuf, q)
|
||||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
|
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
|
||||||
"vpnAddr", fwPacket.RemoteAddr,
|
"vpnAddr", fwPacket.RemoteAddr,
|
||||||
@@ -73,10 +95,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason == nil {
|
if dropReason == nil {
|
||||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, rejectBuf, q)
|
||||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(f.l).Debug("dropping outbound packet",
|
hostinfo.logger(f.l).Debug("dropping outbound packet",
|
||||||
"fwPacket", fwPacket,
|
"fwPacket", fwPacket,
|
||||||
@@ -86,6 +107,124 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Interface) sendInsideEncrypt(hostinfo *HostInfo, ci *ConnectionState, seg, scratch, nb []byte) []byte {
|
||||||
|
if noiseutil.EncryptLockNeeded {
|
||||||
|
ci.writeLock.Lock()
|
||||||
|
}
|
||||||
|
c := ci.messageCounter.Add(1)
|
||||||
|
|
||||||
|
out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
|
||||||
|
f.connectionManager.Out(hostinfo)
|
||||||
|
|
||||||
|
out, encErr := ci.eKey.EncryptDanger(out, out, seg, c, nb)
|
||||||
|
if noiseutil.EncryptLockNeeded {
|
||||||
|
ci.writeLock.Unlock()
|
||||||
|
}
|
||||||
|
if encErr != nil {
|
||||||
|
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
|
||||||
|
"error", encErr,
|
||||||
|
"udpAddr", hostinfo.remote,
|
||||||
|
"counter", c,
|
||||||
|
)
|
||||||
|
// Skip this segment; the rest of the superpacket can still
|
||||||
|
// go out — TCP will retransmit anything we drop here.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendInsideMessage encrypts a firewall-approved inside packet (or every
|
||||||
|
// segment of a TSO/USO superpacket) into the caller's batch slot for
|
||||||
|
// later sendmmsg flush. Segmentation is fused with encryption here so the
|
||||||
|
// kernel-supplied superpacket bytes never get written into a separate
|
||||||
|
// scratch arena: SegmentSuperpacket builds each segment's plaintext in
|
||||||
|
// segScratch[:segLen] in turn, and we encrypt directly into a fresh
|
||||||
|
// SendBatch slot.
|
||||||
|
func (f *Interface) sendInsideMessage(hostinfo *HostInfo, pkt tio.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) {
|
||||||
|
ci := hostinfo.ConnectionState
|
||||||
|
if ci.eKey == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hostinfo.lastRebindCount != f.rebindCount {
|
||||||
|
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
||||||
|
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||||
|
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||||
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
hostinfo.logger(f.l).Debug("Lighthouse update triggered for punch due to rebind counter",
|
||||||
|
"vpnAddrs", hostinfo.vpnAddrs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hostinfo.remote.IsValid() { //the relay path
|
||||||
|
//first, find our relay hostinfo:
|
||||||
|
var relayHostInfo *HostInfo
|
||||||
|
var relay *Relay
|
||||||
|
var err error
|
||||||
|
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||||
|
relayHostInfo, relay, err = f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.relayState.DeleteRelay(relayIP)
|
||||||
|
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo",
|
||||||
|
"relay", relayIP,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if relayHostInfo == nil || relay == nil {
|
||||||
|
//failure already logged
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tio.SegmentSuperpacket(pkt, func(seg []byte) error {
|
||||||
|
//relay header + header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305) + relay tag
|
||||||
|
scratch := sendBatch.Reserve(header.Len + header.Len + len(seg) + 16 + 16)
|
||||||
|
|
||||||
|
innerPacket := f.sendInsideEncrypt(hostinfo, ci, seg, scratch[header.Len:], nb)
|
||||||
|
if innerPacket == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//now we need to do a relay-encrypt:
|
||||||
|
toSend, err := f.prepareSendVia(relayHostInfo, relay, innerPacket, nb, scratch, true)
|
||||||
|
if err != nil {
|
||||||
|
//already logged
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sendBatch.Commit(toSend, relayHostInfo.remote, 0)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).Error("Failed to segment superpacket for relay send", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
|
||||||
|
// header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305)
|
||||||
|
scratch := sendBatch.Reserve(header.Len + len(seg) + 16)
|
||||||
|
|
||||||
|
out := f.sendInsideEncrypt(hostinfo, ci, seg, scratch, nb)
|
||||||
|
if out == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sendBatch.Commit(out, hostinfo.remote, 0)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).Error("Failed to segment superpacket for send",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||||
if !f.firewall.InSendReject {
|
if !f.firewall.InSendReject {
|
||||||
return
|
return
|
||||||
@@ -275,21 +414,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
|
|||||||
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
|
func (f *Interface) prepareSendVia(via *HostInfo,
|
||||||
// to the payload for the ultimate target host, making this a useful method for sending
|
|
||||||
// handshake messages to peers through relay tunnels.
|
|
||||||
// via is the HostInfo through which the message is relayed.
|
|
||||||
// ad is the plaintext data to authenticate, but not encrypt
|
|
||||||
// nb is a buffer used to store the nonce value, re-used for performance reasons.
|
|
||||||
// out is a buffer used to store the result of the Encrypt operation
|
|
||||||
// q indicates which writer to use to send the packet.
|
|
||||||
func (f *Interface) SendVia(via *HostInfo,
|
|
||||||
relay *Relay,
|
relay *Relay,
|
||||||
ad,
|
ad,
|
||||||
nb,
|
nb,
|
||||||
out []byte,
|
out []byte,
|
||||||
nocopy bool,
|
nocopy bool,
|
||||||
) {
|
) ([]byte, error) {
|
||||||
if noiseutil.EncryptLockNeeded {
|
if noiseutil.EncryptLockNeeded {
|
||||||
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
|
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
|
||||||
via.ConnectionState.writeLock.Lock()
|
via.ConnectionState.writeLock.Lock()
|
||||||
@@ -311,7 +442,7 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
"headerLen", len(out),
|
"headerLen", len(out),
|
||||||
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
|
"cipherOverhead", via.ConnectionState.eKey.Overhead(),
|
||||||
)
|
)
|
||||||
return
|
return nil, io.ErrShortBuffer
|
||||||
}
|
}
|
||||||
|
|
||||||
// The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload.
|
// The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload.
|
||||||
@@ -331,13 +462,32 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
|
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
err = f.writers[0].WriteTo(out, via.remote)
|
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
|
||||||
|
// to the payload for the ultimate target host, making this a useful method for sending
|
||||||
|
// handshake messages to peers through relay tunnels.
|
||||||
|
// via is the HostInfo through which the message is relayed.
|
||||||
|
// ad is the plaintext data to authenticate, but not encrypt
|
||||||
|
// nb is a buffer used to store the nonce value, re-used for performance reasons.
|
||||||
|
// out is a buffer used to store the result of the Encrypt operation
|
||||||
|
// q indicates which writer to use to send the packet.
|
||||||
|
func (f *Interface) SendVia(via *HostInfo,
|
||||||
|
relay *Relay,
|
||||||
|
ad,
|
||||||
|
nb,
|
||||||
|
out []byte,
|
||||||
|
nocopy bool,
|
||||||
|
) {
|
||||||
|
toSend, err := f.prepareSendVia(via, relay, ad, nb, out, nocopy)
|
||||||
|
err = f.writers[0].WriteTo(toSend, via.remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
|
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
|
||||||
}
|
}
|
||||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
|
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
|
||||||
|
|||||||
60
interface.go
60
interface.go
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -13,11 +12,12 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
|
"github.com/slackhq/nebula/overlay/batch"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,7 +47,8 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *slog.Logger
|
|
||||||
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
@@ -88,8 +89,12 @@ type Interface struct {
|
|||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []tio.Queue
|
||||||
wg sync.WaitGroup
|
// batchers is one per tun queue, wrapping readers[i].
|
||||||
|
// decryptToTun sends plaintext into the batch.RxBatcher;
|
||||||
|
// listenOut calls its Flush at the end of each UDP recvmmsg batch.
|
||||||
|
batchers []batch.RxBatcher
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// fatalErr holds the first unexpected reader error that caused shutdown.
|
// fatalErr holds the first unexpected reader error that caused shutdown.
|
||||||
// nil means "no fatal error" (yet)
|
// nil means "no fatal error" (yet)
|
||||||
@@ -187,7 +192,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
routines: c.routines,
|
routines: c.routines,
|
||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]udp.Conn, c.routines),
|
writers: make([]udp.Conn, c.routines),
|
||||||
readers: make([]io.ReadWriteCloser, c.routines),
|
readers: make([]tio.Queue, c.routines),
|
||||||
|
batchers: make([]batch.RxBatcher, c.routines),
|
||||||
myVpnNetworks: cs.myVpnNetworks,
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
myVpnAddrs: cs.myVpnAddrs,
|
myVpnAddrs: cs.myVpnAddrs,
|
||||||
@@ -245,15 +251,17 @@ func (f *Interface) activate() error {
|
|||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
var reader io.ReadWriteCloser = f.inside
|
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
if err = f.inside.NewMultiQueueReader(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
}
|
||||||
|
f.readers = f.inside.Readers()
|
||||||
|
for i := range f.readers {
|
||||||
|
arena := batch.NewArena(batch.DefaultPassthroughArenaCap)
|
||||||
|
f.batchers[i] = batch.NewPassthrough(f.readers[i], arena)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.wg.Add(1) // for us to wait on Close() to return
|
f.wg.Add(1) // for us to wait on Close() to return
|
||||||
@@ -311,14 +319,22 @@ func (f *Interface) listenOut(i int) {
|
|||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
plaintext := make([]byte, udp.MTU)
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
listener := func(fromUdpAddr netip.AddrPort, payload []byte, meta udp.RxMeta) {
|
||||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
|
plaintext := f.batchers[i].Reserve(len(payload))
|
||||||
})
|
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(), meta)
|
||||||
|
}
|
||||||
|
|
||||||
|
flusher := func() {
|
||||||
|
if err := f.batchers[i].Flush(); err != nil {
|
||||||
|
f.l.Error("Failed to flush tun coalescer", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := li.ListenOut(listener, flusher)
|
||||||
|
|
||||||
if err != nil && !f.closed.Load() {
|
if err != nil && !f.closed.Load() {
|
||||||
f.l.Error("Error while reading inbound packet, closing", "error", err)
|
f.l.Error("Error while reading inbound packet, closing", "error", err)
|
||||||
@@ -328,16 +344,17 @@ func (f *Interface) listenOut(i int) {
|
|||||||
f.l.Debug("underlay reader is done", "reader", i)
|
f.l.Debug("underlay reader is done", "reader", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader tio.Queue, i int) {
|
||||||
packet := make([]byte, mtu)
|
rejectBuf := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
arenaSize := batch.SendBatchCap * (udp.MTU + 32)
|
||||||
|
sb := batch.NewSendBatch(f.writers[i], batch.SendBatchCap, arenaSize)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
pkts, err := reader.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !f.closed.Load() {
|
if !f.closed.Load() {
|
||||||
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
|
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
|
||||||
@@ -346,7 +363,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
for _, pkt := range pkts {
|
||||||
|
f.consumeInsidePacket(pkt, fwPacket, nb, sb, rejectBuf, i, conntrackCache.Get())
|
||||||
|
}
|
||||||
|
if err := sb.Flush(); err != nil {
|
||||||
|
f.l.Error("Failed to write outgoing batch", "error", err, "writer", i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.Debug("overlay reader is done", "reader", i)
|
f.l.Debug("overlay reader is done", "reader", i)
|
||||||
|
|||||||
18
outside.go
18
outside.go
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ const (
|
|||||||
|
|
||||||
var ErrOutOfWindow = errors.New("out of window packet")
|
var ErrOutOfWindow = errors.New("out of window packet")
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
@@ -110,8 +111,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
|
|
||||||
// Relay packets are special
|
// Relay packets are special
|
||||||
if isMessageRelay {
|
if isMessageRelay {
|
||||||
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache)
|
f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache, meta)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,7 +135,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
case header.Message:
|
case header.Message:
|
||||||
switch h.Subtype {
|
switch h.Subtype {
|
||||||
case header.MessageNone:
|
case header.MessageNone:
|
||||||
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache)
|
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache, meta)
|
||||||
default:
|
default:
|
||||||
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
|
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
|
||||||
return
|
return
|
||||||
@@ -168,7 +168,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
|
||||||
// The entire body is sent as AD, not encrypted.
|
// The entire body is sent as AD, not encrypted.
|
||||||
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
||||||
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
||||||
@@ -211,7 +211,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
|
|||||||
relay: relay,
|
relay: relay,
|
||||||
IsRelayed: true,
|
IsRelayed: true,
|
||||||
}
|
}
|
||||||
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, meta)
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||||
@@ -229,7 +229,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
|
|||||||
switch targetRelay.Type {
|
switch targetRelay.Type {
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
// Forward this packet through the relay tunnel
|
// Forward this packet through the relay tunnel
|
||||||
// Find the target HostInfo
|
// Find the target HostInfo //todo it would potentially be nice to batch these
|
||||||
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
|
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||||
@@ -512,7 +512,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
|
||||||
err := newPacket(out, true, fwPacket)
|
err := newPacket(out, true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
||||||
@@ -536,7 +536,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = f.readers[q].Write(out)
|
err = f.batchers[q].Commit(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Error("Failed to write to tun", "error", err)
|
f.l.Error("Failed to write to tun", "error", err)
|
||||||
}
|
}
|
||||||
|
|||||||
28
overlay/batch/batch.go
Normal file
28
overlay/batch/batch.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
type RxBatcher interface {
|
||||||
|
// Reserve creates a pkt to borrow
|
||||||
|
Reserve(sz int) []byte
|
||||||
|
// Commit borrows pkt. The caller must keep pkt valid until the next Flush
|
||||||
|
Commit(pkt []byte) error
|
||||||
|
// Flush emits every queued packet in arrival order. Returns the
|
||||||
|
// first error observed; keeps draining so one bad packet doesn't hold up
|
||||||
|
// the rest. After Flush returns, borrowed payload slices may be recycled.
|
||||||
|
Flush() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TxBatcher interface {
|
||||||
|
// Reserve creates a pkt to borrow
|
||||||
|
Reserve(sz int) []byte
|
||||||
|
// Commit borrows pkt and records its destination plus the 2-bit
|
||||||
|
// IP-level ECN codepoint to set on the outer (carrier) header. The
|
||||||
|
// caller must keep pkt valid until the next Flush. Pass 0 (Not-ECT)
|
||||||
|
// to leave the outer ECN field unset.
|
||||||
|
Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
|
||||||
|
// Flush emits every queued packet via the underlying batch writer in
|
||||||
|
// arrival order. Returns an errors.Join of one or more errors. After Flush returns,
|
||||||
|
// borrowed payload slices may be recycled.
|
||||||
|
Flush() error
|
||||||
|
}
|
||||||
42
overlay/batch/coalesce_core.go
Normal file
42
overlay/batch/coalesce_core.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
// Arena is an injectable byte-slab that hands out non-overlapping borrowed
|
||||||
|
// slices via Reserve and releases them in bulk via Reset. Coalescers take
|
||||||
|
// an *Arena at construction so the caller controls the slab lifetime and
|
||||||
|
// can share one slab across multiple coalescers (MultiCoalescer hands the
|
||||||
|
// same *Arena to every lane so the lanes don't carry their own backings).
|
||||||
|
//
|
||||||
|
// Reserve borrows; the slice is valid until the next Reset. The slab grows
|
||||||
|
// (by allocating a fresh, larger backing array) if a Reserve doesn't fit;
|
||||||
|
// pre-size the arena via NewArena to avoid that path on the hot path.
|
||||||
|
type Arena struct {
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewArena returns an Arena with a pre-allocated backing of the given
|
||||||
|
// capacity. Pass 0 if you don't intend to call Reserve (e.g. a test that
|
||||||
|
// only feeds the coalescer pre-made []byte packets via Commit).
|
||||||
|
func NewArena(capacity int) *Arena {
|
||||||
|
return &Arena{buf: make([]byte, 0, capacity)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reserve hands out a non-overlapping sz-byte slice from the arena. If the
|
||||||
|
// request doesn't fit the current backing, a fresh, larger backing is
|
||||||
|
// allocated; already-borrowed slices reference the old backing and remain
|
||||||
|
// valid until Reset.
|
||||||
|
func (a *Arena) Reserve(sz int) []byte {
|
||||||
|
if len(a.buf)+sz > cap(a.buf) {
|
||||||
|
newCap := max(cap(a.buf)*2, sz)
|
||||||
|
a.buf = make([]byte, 0, newCap)
|
||||||
|
}
|
||||||
|
start := len(a.buf)
|
||||||
|
a.buf = a.buf[:start+sz]
|
||||||
|
return a.buf[start : start+sz : start+sz]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset releases every slice handed out since the last Reset. Callers must
|
||||||
|
// not use any previously-borrowed slice after this returns. The underlying
|
||||||
|
// backing array is retained so subsequent Reserves don't re-allocate.
|
||||||
|
func (a *Arena) Reset() {
|
||||||
|
a.buf = a.buf[:0]
|
||||||
|
}
|
||||||
52
overlay/batch/passthrough.go
Normal file
52
overlay/batch/passthrough.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Passthrough is a RxBatcher that doesn't batch anything, it just accumulates and then sends packets.
|
||||||
|
type Passthrough struct {
|
||||||
|
out io.Writer
|
||||||
|
slots [][]byte
|
||||||
|
arena *Arena
|
||||||
|
cursor int
|
||||||
|
}
|
||||||
|
|
||||||
|
const passthroughBaseNumSlots = 128
|
||||||
|
|
||||||
|
// DefaultPassthroughArenaCap is the recommended arena capacity for a
|
||||||
|
// standalone Passthrough batcher: 128 slots × udp.MTU ≈ 1.1 MiB.
|
||||||
|
const DefaultPassthroughArenaCap = passthroughBaseNumSlots * udp.MTU
|
||||||
|
|
||||||
|
func NewPassthrough(w io.Writer, arena *Arena) *Passthrough {
|
||||||
|
return &Passthrough{
|
||||||
|
out: w,
|
||||||
|
slots: make([][]byte, 0, passthroughBaseNumSlots),
|
||||||
|
arena: arena,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Passthrough) Reserve(sz int) []byte {
|
||||||
|
return p.arena.Reserve(sz)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Passthrough) Commit(pkt []byte) error {
|
||||||
|
p.slots = append(p.slots, pkt)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Passthrough) Flush() error {
|
||||||
|
var firstErr error
|
||||||
|
for _, s := range p.slots {
|
||||||
|
_, err := p.out.Write(s)
|
||||||
|
if err != nil && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clear(p.slots)
|
||||||
|
p.slots = p.slots[:0]
|
||||||
|
p.arena.Reset()
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
65
overlay/batch/tx_batch.go
Normal file
65
overlay/batch/tx_batch.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
const SendBatchCap = 128
|
||||||
|
|
||||||
|
// batchWriter is the minimal subset of udp.Conn needed by SendBatch to flush.
|
||||||
|
type batchWriter interface {
|
||||||
|
WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendBatch accumulates encrypted UDP packets and flushes them via WriteBatch.
|
||||||
|
// One SendBatch is owned by each listenIn goroutine; no locking is needed.
|
||||||
|
// The backing arena grows on demand: when there isn't room for the next slot
|
||||||
|
// we allocate a fresh backing array. Already-committed slices keep referencing
|
||||||
|
// the old array and remain valid until Flush drops them.
|
||||||
|
type SendBatch struct {
|
||||||
|
out batchWriter
|
||||||
|
bufs [][]byte
|
||||||
|
dsts []netip.AddrPort
|
||||||
|
ecns []byte
|
||||||
|
backing []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSendBatch makes a SendBatch with batchCap slots and an arenaSize byte buffer for slices to back those slots
|
||||||
|
func NewSendBatch(out batchWriter, batchCap, arenaSize int) *SendBatch {
|
||||||
|
return &SendBatch{
|
||||||
|
out: out,
|
||||||
|
bufs: make([][]byte, 0, batchCap),
|
||||||
|
dsts: make([]netip.AddrPort, 0, batchCap),
|
||||||
|
ecns: make([]byte, 0, batchCap),
|
||||||
|
backing: make([]byte, 0, arenaSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *SendBatch) Reserve(sz int) []byte {
|
||||||
|
if len(b.backing)+sz > cap(b.backing) {
|
||||||
|
// Grow: allocate a fresh backing. Already-committed slices still
|
||||||
|
// reference the old array and remain valid until Flush drops them.
|
||||||
|
newCap := max(cap(b.backing)*2, sz)
|
||||||
|
b.backing = make([]byte, 0, newCap)
|
||||||
|
}
|
||||||
|
start := len(b.backing)
|
||||||
|
b.backing = b.backing[:start+sz]
|
||||||
|
return b.backing[start : start+sz : start+sz]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *SendBatch) Commit(pkt []byte, dst netip.AddrPort, outerECN byte) {
|
||||||
|
b.bufs = append(b.bufs, pkt)
|
||||||
|
b.dsts = append(b.dsts, dst)
|
||||||
|
b.ecns = append(b.ecns, outerECN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *SendBatch) Flush() error {
|
||||||
|
var err error
|
||||||
|
if len(b.bufs) > 0 {
|
||||||
|
err = b.out.WriteBatch(b.bufs, b.dsts, b.ecns)
|
||||||
|
}
|
||||||
|
clear(b.bufs)
|
||||||
|
b.bufs = b.bufs[:0]
|
||||||
|
b.dsts = b.dsts[:0]
|
||||||
|
b.ecns = b.ecns[:0]
|
||||||
|
b.backing = b.backing[:0]
|
||||||
|
return err
|
||||||
|
}
|
||||||
124
overlay/batch/tx_batch_test.go
Normal file
124
overlay/batch/tx_batch_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeBatchWriter struct {
|
||||||
|
bufs [][]byte
|
||||||
|
addrs []netip.AddrPort
|
||||||
|
ecns []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fakeBatchWriter) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error {
|
||||||
|
// Snapshot — SendBatch.Flush nils its slot pointers right after WriteBatch
|
||||||
|
// returns, so tests must capture data before that happens.
|
||||||
|
w.bufs = make([][]byte, len(bufs))
|
||||||
|
for i, b := range bufs {
|
||||||
|
cp := make([]byte, len(b))
|
||||||
|
copy(cp, b)
|
||||||
|
w.bufs[i] = cp
|
||||||
|
}
|
||||||
|
w.addrs = append(w.addrs[:0], addrs...)
|
||||||
|
w.ecns = append(w.ecns[:0], ecns...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendBatchReserveCommitFlush(t *testing.T) {
|
||||||
|
fw := &fakeBatchWriter{}
|
||||||
|
b := NewSendBatch(fw, 4, 32)
|
||||||
|
|
||||||
|
ap := netip.MustParseAddrPort("10.0.0.1:4242")
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
slot := b.Reserve(32)
|
||||||
|
if cap(slot) != 32 {
|
||||||
|
t.Fatalf("slot %d: cap=%d want 32", i, cap(slot))
|
||||||
|
}
|
||||||
|
pkt := append(slot[:0], byte(i), byte(i+1), byte(i+2))
|
||||||
|
b.Commit(pkt, ap, 0)
|
||||||
|
}
|
||||||
|
if err := b.Flush(); err != nil {
|
||||||
|
t.Fatalf("Flush: %v", err)
|
||||||
|
}
|
||||||
|
if len(fw.bufs) != 4 {
|
||||||
|
t.Fatalf("WriteBatch got %d bufs want 4", len(fw.bufs))
|
||||||
|
}
|
||||||
|
for i, buf := range fw.bufs {
|
||||||
|
if len(buf) != 3 || buf[0] != byte(i) {
|
||||||
|
t.Errorf("buf %d: %x", i, buf)
|
||||||
|
}
|
||||||
|
if fw.addrs[i] != ap {
|
||||||
|
t.Errorf("addr %d: got %v want %v", i, fw.addrs[i], ap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush again with nothing committed — should be a no-op.
|
||||||
|
fw.bufs = nil
|
||||||
|
if err := b.Flush(); err != nil {
|
||||||
|
t.Fatalf("empty Flush: %v", err)
|
||||||
|
}
|
||||||
|
if fw.bufs != nil {
|
||||||
|
t.Fatalf("empty Flush triggered WriteBatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reuse after Flush.
|
||||||
|
slot := b.Reserve(32)
|
||||||
|
if cap(slot) != 32 {
|
||||||
|
t.Fatalf("after Flush Reserve wrong cap: %d", cap(slot))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendBatchSlotsDoNotOverlap(t *testing.T) {
|
||||||
|
fw := &fakeBatchWriter{}
|
||||||
|
b := NewSendBatch(fw, 3, 8)
|
||||||
|
ap := netip.MustParseAddrPort("10.0.0.1:80")
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
s := b.Reserve(8)
|
||||||
|
pkt := append(s[:0], byte(0xA0+i), byte(0xB0+i))
|
||||||
|
b.Commit(pkt, ap, 0)
|
||||||
|
}
|
||||||
|
if err := b.Flush(); err != nil {
|
||||||
|
t.Fatalf("Flush: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, buf := range fw.bufs {
|
||||||
|
if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) {
|
||||||
|
t.Errorf("slot %d corrupted: %x", i, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendBatchGrowPreservesCommitted(t *testing.T) {
|
||||||
|
fw := &fakeBatchWriter{}
|
||||||
|
// Tiny initial backing forces a grow on the second Reserve.
|
||||||
|
b := NewSendBatch(fw, 1, 4)
|
||||||
|
ap := netip.MustParseAddrPort("10.0.0.1:80")
|
||||||
|
|
||||||
|
s1 := b.Reserve(4)
|
||||||
|
pkt1 := append(s1[:0], 0x11, 0x22, 0x33, 0x44)
|
||||||
|
b.Commit(pkt1, ap, 0)
|
||||||
|
|
||||||
|
s2 := b.Reserve(8) // exceeds remaining cap, triggers grow
|
||||||
|
pkt2 := append(s2[:0], 0xA, 0xB, 0xC, 0xD, 0xE)
|
||||||
|
b.Commit(pkt2, ap, 0)
|
||||||
|
|
||||||
|
// pkt1 must still be intact even though backing reallocated.
|
||||||
|
if pkt1[0] != 0x11 || pkt1[3] != 0x44 {
|
||||||
|
t.Fatalf("first packet corrupted by grow: %x", pkt1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := b.Flush(); err != nil {
|
||||||
|
t.Fatalf("Flush: %v", err)
|
||||||
|
}
|
||||||
|
if len(fw.bufs) != 2 {
|
||||||
|
t.Fatalf("got %d bufs want 2", len(fw.bufs))
|
||||||
|
}
|
||||||
|
if fw.bufs[0][0] != 0x11 || fw.bufs[0][3] != 0x44 {
|
||||||
|
t.Errorf("first packet on the wire: %x", fw.bufs[0])
|
||||||
|
}
|
||||||
|
if fw.bufs[1][0] != 0xA || fw.bufs[1][4] != 0xE {
|
||||||
|
t.Errorf("second packet on the wire: %x", fw.bufs[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,15 +4,21 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
|
||||||
|
// that don't do TSO segmentation. 65535 covers any single IP packet.
|
||||||
|
const defaultBatchBufSize = 65535
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
io.ReadWriteCloser
|
io.Closer
|
||||||
Activate() error
|
Activate() error
|
||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
SupportsMultiqueue() bool
|
SupportsMultiqueue() bool
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() error
|
||||||
|
Readers() []tio.Queue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ package overlaytest
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,8 +31,8 @@ func (NoopTun) Name() string {
|
|||||||
return "noop"
|
return "noop"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) Read([]byte) (int, error) {
|
func (NoopTun) Read() ([]tio.Packet, error) {
|
||||||
return 0, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) Write([]byte) (int, error) {
|
func (NoopTun) Write([]byte) (int, error) {
|
||||||
@@ -43,8 +43,12 @@ func (NoopTun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (NoopTun) NewMultiQueueReader() error {
|
||||||
return nil, errors.New("unsupported")
|
return errors.New("unsupported")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (NoopTun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{NoopTun{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) Close() error {
|
func (NoopTun) Close() error {
|
||||||
|
|||||||
69
overlay/tio/queueset_poll_linux.go
Normal file
69
overlay/tio/queueset_poll_linux.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package tio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pollQueueSet struct {
|
||||||
|
pq []*Poll
|
||||||
|
// pqi is exactly the same as pq, but stored as the interface type
|
||||||
|
pqi []Queue
|
||||||
|
shutdownFd int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPollQueueSet() (QueueSet, error) {
|
||||||
|
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &pollQueueSet{
|
||||||
|
pq: []*Poll{},
|
||||||
|
pqi: []Queue{},
|
||||||
|
shutdownFd: shutdownFd,
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pollQueueSet) Queues() []Queue {
|
||||||
|
return c.pqi
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pollQueueSet) Add(fd int) error {
|
||||||
|
x, err := newPoll(fd, c.shutdownFd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.pq = append(c.pq, x)
|
||||||
|
c.pqi = append(c.pqi, x)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pollQueueSet) wakeForShutdown() error {
|
||||||
|
var buf [8]byte
|
||||||
|
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||||
|
_, err := unix.Write(int(c.shutdownFd), buf[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pollQueueSet) Close() error {
|
||||||
|
errs := []error{}
|
||||||
|
|
||||||
|
if err := c.wakeForShutdown(); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, x := range c.pq {
|
||||||
|
if err := x.Close(); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
12
overlay/tio/segment.go
Normal file
12
overlay/tio/segment.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package tio
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// SegmentSuperpacket invokes fn once per segment of pkt.
|
||||||
|
// This is a stub implementation that does not actually support segmentation
|
||||||
|
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
|
||||||
|
if pkt.GSO.IsSuperpacket() {
|
||||||
|
return fmt.Errorf("tio: GSO superpacket on platform without segmentation support")
|
||||||
|
}
|
||||||
|
return fn(pkt.Bytes)
|
||||||
|
}
|
||||||
170
overlay/tio/tio.go
Normal file
170
overlay/tio/tio.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package tio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QueueSet holds one or many Queue objects and helps close them in an orderly way.
|
||||||
|
type QueueSet interface {
|
||||||
|
io.Closer
|
||||||
|
Queues() []Queue
|
||||||
|
|
||||||
|
// Add takes a tun fd, adds it to the set, and prepares it for use as a Queue.
|
||||||
|
Add(fd int) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capabilities advertises which kernel offload features a Queue
|
||||||
|
// successfully negotiated. Callers consult this to decide which coalescers
|
||||||
|
// to wire onto the write path — a Queue without TSO can't usefully accept a
|
||||||
|
// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer.
|
||||||
|
type Capabilities struct {
|
||||||
|
// TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed
|
||||||
|
// to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe.
|
||||||
|
TSO bool
|
||||||
|
// USO means the kernel additionally agreed to TUN_F_USO4|USO6, so
|
||||||
|
// WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2.
|
||||||
|
USO bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queue is a readable/writable Poll queue. One Queue is driven by a single
|
||||||
|
// read goroutine plus a single writer (see Write below).
|
||||||
|
type Queue interface {
|
||||||
|
io.Closer
|
||||||
|
|
||||||
|
// Read returns one or more packets. The returned Packet.Bytes slices
|
||||||
|
// are borrowed from the Queue's internal buffer and are only valid
|
||||||
|
// until the next Read or Close on this Queue - callers must encrypt
|
||||||
|
// or copy each slice before the next call. A Packet may carry a
|
||||||
|
// GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is
|
||||||
|
// true the caller must segment Bytes before treating it as a single
|
||||||
|
// IP datagram. Not safe for concurrent Reads.
|
||||||
|
Read() ([]Packet, error)
|
||||||
|
|
||||||
|
// Write emits a single packet on the plaintext (outside→inside)
|
||||||
|
// delivery path. Not safe for concurrent Writes.
|
||||||
|
Write(p []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Packet is the unit Queue.Read returns. Bytes points into the queue's
|
||||||
|
// internal buffer and is only valid until the next Read or Close on the
|
||||||
|
// queue that produced it. GSO is the zero value for an already-segmented
|
||||||
|
// IP datagram; when non-zero it describes a kernel-supplied TSO/USO
|
||||||
|
// superpacket the caller must segment before consuming.
|
||||||
|
type Packet struct {
|
||||||
|
Bytes []byte
|
||||||
|
GSO GSOInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes.
|
||||||
|
// The zero value means "not a superpacket" — Bytes is one regular IP
|
||||||
|
// datagram and no segmentation is required.
|
||||||
|
type GSOInfo struct {
|
||||||
|
// Size is the GSO segment size: max payload bytes per segment
|
||||||
|
// (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means
|
||||||
|
// not a superpacket.
|
||||||
|
Size uint16
|
||||||
|
// HdrLen is the total L3+L4 header length within Bytes (already
|
||||||
|
// corrected via correctHdrLen, so safe to slice on).
|
||||||
|
HdrLen uint16
|
||||||
|
// CsumStart is the L4 header offset inside Bytes (== L3 header
|
||||||
|
// length).
|
||||||
|
CsumStart uint16
|
||||||
|
// Proto picks the L4 protocol (TCP or UDP) so the segmenter knows
|
||||||
|
// which checksum/header layout to apply.
|
||||||
|
Proto GSOProto
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSuperpacket reports whether g describes a multi-segment GSO/USO
|
||||||
|
// superpacket that needs segmentation before its bytes can be encrypted
|
||||||
|
// and sent on the wire.
|
||||||
|
func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 }
|
||||||
|
|
||||||
|
// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes,
|
||||||
|
// safe to retain past the next Read or Close on the originating Queue.
|
||||||
|
// GSO metadata is copied verbatim. Use this only when a caller genuinely
|
||||||
|
// needs to outlive the borrowed-slice contract — the hot path reads should
|
||||||
|
// continue to consume the borrow synchronously to avoid the allocation.
|
||||||
|
func (p Packet) Clone() Packet {
|
||||||
|
if p.Bytes == nil {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
cp := make([]byte, len(p.Bytes))
|
||||||
|
copy(cp, p.Bytes)
|
||||||
|
return Packet{Bytes: cp, GSO: p.GSO}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CapsProvider is an optional interface implemented by Queues that
|
||||||
|
// successfully negotiated kernel offload features at open time. Callers
|
||||||
|
// pick a write-path coalescer based on the result. Queues that don't
|
||||||
|
// implement it are treated as having no offload capability — callers must
|
||||||
|
// fall back to plain per-packet writes.
|
||||||
|
type CapsProvider interface {
|
||||||
|
Capabilities() Capabilities
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueueCapabilities returns q's negotiated offload capabilities, or the
|
||||||
|
// zero value when q does not advertise any.
|
||||||
|
func QueueCapabilities(q Queue) Capabilities {
|
||||||
|
if cp, ok := q.(CapsProvider); ok {
|
||||||
|
return cp.Capabilities()
|
||||||
|
}
|
||||||
|
return Capabilities{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GSOProto selects the L4 protocol for a GSO superpacket. Determines which
|
||||||
|
// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset
|
||||||
|
// inside the transport header virtio NEEDS_CSUM expects.
|
||||||
|
type GSOProto uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
GSOProtoTCP GSOProto = iota
|
||||||
|
GSOProtoUDP
|
||||||
|
)
|
||||||
|
|
||||||
|
// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket
|
||||||
|
// assembled from a header prefix plus one or more borrowed payload
|
||||||
|
// fragments, in a single vectored write (writev with a leading
|
||||||
|
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
|
||||||
|
// between the caller's decrypt buffer and the TUN. Backends without GSO
|
||||||
|
// support do not implement this interface and coalescing is skipped.
|
||||||
|
//
|
||||||
|
// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have
|
||||||
|
// filled in total length and IP csum). transportHdr is the TCP or UDP
|
||||||
|
// header (mutable - the L4 checksum field must hold the pseudo-header
|
||||||
|
// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics).
|
||||||
|
// pays are non-overlapping payload fragments whose concatenation is the
|
||||||
|
// full superpacket payload; they are read-only from the writer's
|
||||||
|
// perspective and must remain valid until the call returns. Every segment
|
||||||
|
// in pays except possibly the last is exactly the same size. proto picks
|
||||||
|
// the L4 protocol so the writer knows which GSOType / CsumOffset to set.
|
||||||
|
//
|
||||||
|
// Callers should also consult CapsProvider (via SupportsGSO or
|
||||||
|
// QueueCapabilities) for the per-protocol negotiated capability; an
|
||||||
|
// implementation of GSOWriter is necessary but not sufficient since USO
|
||||||
|
// may not have been negotiated even when TSO was.
|
||||||
|
type GSOWriter interface {
|
||||||
|
WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportsGSO reports whether w implements GSOWriter and the underlying
|
||||||
|
// queue advertises the negotiated capability for `want`. A writer that
|
||||||
|
// implements GSOWriter but not CapsProvider is treated as permissive
|
||||||
|
// (used by tests and fakes that don't negotiate).
|
||||||
|
func SupportsGSO(w any, want GSOProto) (GSOWriter, bool) {
|
||||||
|
gw, ok := w.(GSOWriter)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
cp, ok := w.(CapsProvider)
|
||||||
|
if !ok {
|
||||||
|
return gw, true
|
||||||
|
}
|
||||||
|
caps := cp.Capabilities()
|
||||||
|
switch want {
|
||||||
|
case GSOProtoTCP:
|
||||||
|
return gw, caps.TSO
|
||||||
|
case GSOProtoUDP:
|
||||||
|
return gw, caps.USO
|
||||||
|
}
|
||||||
|
return gw, false
|
||||||
|
}
|
||||||
168
overlay/tio/tio_poll_linux.go
Normal file
168
overlay/tio/tio_poll_linux.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
package tio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Poll struct {
|
||||||
|
fd int
|
||||||
|
|
||||||
|
readPoll [2]unix.PollFd
|
||||||
|
writePoll [2]unix.PollFd
|
||||||
|
writeLock sync.Mutex
|
||||||
|
closed atomic.Bool
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPoll(fd int, shutdownFd int) (*Poll, error) {
|
||||||
|
if err := unix.SetNonblock(fd, true); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &Poll{
|
||||||
|
fd: fd,
|
||||||
|
readBuf: make([]byte, 65535),
|
||||||
|
readPoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
writePoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||||
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
writeLock: sync.Mutex{},
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// blockOnRead waits until the Poll fd is readable or shutdown has been signaled.
|
||||||
|
// Returns os.ErrClosed if Close was called.
|
||||||
|
func (t *Poll) blockOnRead() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(t.readPoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tunEvents := t.readPoll[0].Revents
|
||||||
|
shutdownEvents := t.readPoll[1].Revents
|
||||||
|
t.readPoll[0].Revents = 0
|
||||||
|
t.readPoll[1].Revents = 0
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Poll) blockOnWrite() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(t.writePoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.writeLock.Lock()
|
||||||
|
tunEvents := t.writePoll[0].Revents
|
||||||
|
shutdownEvents := t.writePoll[1].Revents
|
||||||
|
t.writePoll[0].Revents = 0
|
||||||
|
t.writePoll[1].Revents = 0
|
||||||
|
t.writeLock.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Poll) Read() ([]Packet, error) {
|
||||||
|
n, err := t.readOne(t.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Poll) readOne(to []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
n, errno := unix.Read(t.fd, to)
|
||||||
|
if errno == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
switch errno {
|
||||||
|
case unix.EAGAIN:
|
||||||
|
if err := t.blockOnRead(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
case unix.EINTR:
|
||||||
|
// retry
|
||||||
|
case unix.EBADF:
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
default:
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is only valid for single threaded use
|
||||||
|
func (t *Poll) Write(from []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
n, errno := unix.Write(t.fd, from)
|
||||||
|
if errno == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
switch errno {
|
||||||
|
case unix.EAGAIN:
|
||||||
|
if err := t.blockOnWrite(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
case unix.EINTR:
|
||||||
|
// retry
|
||||||
|
case unix.EBADF:
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
default:
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Poll) Close() error {
|
||||||
|
if t.closed.Swap(true) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
//shutdownFd is owned by the container, so we should not close it
|
||||||
|
var err error
|
||||||
|
if t.fd >= 0 {
|
||||||
|
err = unix.Close(t.fd)
|
||||||
|
t.fd = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Poll) Capabilities() Capabilities {
|
||||||
|
return Capabilities{TSO: false, USO: false}
|
||||||
|
}
|
||||||
82
overlay/tio/tun_file_linux_test.go
Normal file
82
overlay/tio/tun_file_linux_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
//go:build linux && !android && !e2e_testing
|
||||||
|
// +build linux,!android,!e2e_testing
|
||||||
|
|
||||||
|
package tio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
||||||
|
// The caller takes ownership of the read fd (pass it into a QueueSet).
|
||||||
|
func newReadPipe(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
var fds [2]int
|
||||||
|
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
|
||||||
|
t.Fatalf("pipe2: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = unix.Close(fds[1]) })
|
||||||
|
return fds[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||||
|
pipe1 := newReadPipe(t)
|
||||||
|
pipe2 := newReadPipe(t)
|
||||||
|
parent, err := NewPollQueueSet()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, parent.Add(pipe1))
|
||||||
|
require.NoError(t, parent.Add(pipe2))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = unix.Close(pipe1)
|
||||||
|
_ = unix.Close(pipe2)
|
||||||
|
})
|
||||||
|
|
||||||
|
readers := parent.Queues()
|
||||||
|
errs := make([]error, len(readers))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i, r := range readers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int, r Queue) {
|
||||||
|
defer wg.Done()
|
||||||
|
_, errs[i] = r.Read()
|
||||||
|
}(i, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
if err := parent.Close(); err != nil {
|
||||||
|
t.Fatalf("Close: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { wg.Wait(); close(done) }()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("readers did not wake")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, err := range errs {
|
||||||
|
if !errors.Is(err, os.ErrClosed) {
|
||||||
|
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoll_Close_Idempotent(t *testing.T) {
|
||||||
|
tf, err := newPoll(newReadPipe(t), 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if err := tf.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close: %v", err)
|
||||||
|
}
|
||||||
|
if err := tf.Close(); err != nil {
|
||||||
|
t.Fatalf("second Close should be a no-op, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,17 +13,38 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
rwc io.ReadWriteCloser
|
||||||
fd int
|
fd int
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read() ([]tio.Packet, error) {
|
||||||
|
n, err := t.rwc.Read(t.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Write(p []byte) (int, error) {
|
||||||
|
return t.rwc.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Close() error {
|
||||||
|
return t.rwc.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
@@ -32,10 +53,11 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
|
|||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
rwc: file,
|
||||||
fd: deviceFd,
|
fd: deviceFd,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
l: l,
|
l: l,
|
||||||
|
readBuf: make([]byte, defaultBatchBufSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
@@ -62,7 +84,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,6 +121,10 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
return fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
@@ -23,7 +24,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
rwc io.ReadWriteCloser
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
@@ -34,6 +35,9 @@ type tun struct {
|
|||||||
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
out []byte
|
out []byte
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
@@ -124,11 +128,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
|||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
rwc: os.NewFile(uintptr(fd), ""),
|
||||||
Device: name,
|
Device: name,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
|
readBuf: make([]byte, defaultBatchBufSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -158,8 +163,8 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
if t.ReadWriteCloser != nil {
|
if t.rwc != nil {
|
||||||
return t.ReadWriteCloser.Close()
|
return t.rwc.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -502,15 +507,24 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
func (t *tun) readOne(to []byte) (int, error) {
|
||||||
buf := make([]byte, len(to)+4)
|
buf := make([]byte, len(to)+4)
|
||||||
|
|
||||||
n, err := t.ReadWriteCloser.Read(buf)
|
n, err := t.rwc.Read(buf)
|
||||||
|
|
||||||
copy(to, buf[4:])
|
copy(to, buf[4:])
|
||||||
return n - 4, err
|
return n - 4, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read() ([]tio.Packet, error) {
|
||||||
|
n, err := t.readOne(t.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
// Write is only valid for single threaded use
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
buf := t.out
|
buf := t.out
|
||||||
@@ -536,7 +550,7 @@ func (t *tun) Write(from []byte) (int, error) {
|
|||||||
|
|
||||||
copy(buf[4:], from)
|
copy(buf[4:], from)
|
||||||
|
|
||||||
n, err := t.ReadWriteCloser.Write(buf)
|
n, err := t.rwc.Write(buf)
|
||||||
return n - 4, err
|
return n - 4, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -552,6 +566,10 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,9 +19,45 @@ type disabledTun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
|
|
||||||
// Track these metrics since we don't have the tun device to do it for us
|
// Track these metrics since we don't have the tun device to do it for us
|
||||||
tx metrics.Counter
|
tx metrics.Counter
|
||||||
rx metrics.Counter
|
rx metrics.Counter
|
||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
|
numReaders int
|
||||||
|
}
|
||||||
|
|
||||||
|
// disabledQueue is one tio.Queue view onto a shared disabledTun. Each queue
|
||||||
|
// owns a private batchRet so concurrent Read calls from different reader
|
||||||
|
// goroutines do not race on the returned slice.
|
||||||
|
type disabledQueue struct {
|
||||||
|
parent *disabledTun
|
||||||
|
batchRet [1]tio.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *disabledQueue) Read() ([]tio.Packet, error) {
|
||||||
|
r, ok := <-q.parent.read
|
||||||
|
if !ok {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
q.parent.tx.Inc(1)
|
||||||
|
if q.parent.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
q.parent.l.Debug("Write payload", "raw", prettyPacket(r))
|
||||||
|
}
|
||||||
|
|
||||||
|
q.batchRet[0] = tio.Packet{Bytes: r}
|
||||||
|
return q.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write on a queue forwards to the underlying disabledTun. All queues share
|
||||||
|
// one ICMP-handling/log path so this is a thin pass-through.
|
||||||
|
func (q *disabledQueue) Write(b []byte) (int, error) {
|
||||||
|
return q.parent.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close on a queue is a no-op. The shared channel and metrics are owned by
|
||||||
|
// the disabledTun; Close on the device tears them down once for everybody.
|
||||||
|
func (q *disabledQueue) Close() error {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
|
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
|
||||||
@@ -28,6 +65,7 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
|
|||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
read: make(chan []byte, queueLen),
|
read: make(chan []byte, queueLen),
|
||||||
l: l,
|
l: l,
|
||||||
|
numReaders: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metricsEnabled {
|
if metricsEnabled {
|
||||||
@@ -57,24 +95,6 @@ func (*disabledTun) Name() string {
|
|||||||
return "disabled"
|
return "disabled"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) Read(b []byte) (int, error) {
|
|
||||||
r, ok := <-t.read
|
|
||||||
if !ok {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(r) > len(b) {
|
|
||||||
return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
t.tx.Inc(1)
|
|
||||||
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
|
||||||
t.l.Debug("Write payload", "raw", prettyPacket(r))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(b, r), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
||||||
out := make([]byte, len(b))
|
out := make([]byte, len(b))
|
||||||
out = iputil.CreateICMPEchoResponse(b, out)
|
out = iputil.CreateICMPEchoResponse(b, out)
|
||||||
@@ -110,8 +130,17 @@ func (t *disabledTun) SupportsMultiqueue() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *disabledTun) NewMultiQueueReader() error {
|
||||||
return t, nil
|
t.numReaders++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) Readers() []tio.Queue {
|
||||||
|
out := make([]tio.Queue, t.numReaders)
|
||||||
|
for i := range t.numReaders {
|
||||||
|
out[i] = &disabledQueue{parent: t}
|
||||||
|
}
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) Close() error {
|
func (t *disabledTun) Close() error {
|
||||||
|
|||||||
@@ -1,120 +0,0 @@
|
|||||||
//go:build linux && !android && !e2e_testing
|
|
||||||
// +build linux,!android,!e2e_testing
|
|
||||||
|
|
||||||
package overlay
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
|
||||||
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
|
|
||||||
func newReadPipe(t *testing.T) int {
|
|
||||||
t.Helper()
|
|
||||||
var fds [2]int
|
|
||||||
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
|
|
||||||
t.Fatalf("pipe2: %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { _ = unix.Close(fds[1]) })
|
|
||||||
return fds[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
|
|
||||||
tf, err := newTunFd(newReadPipe(t))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("newTunFd: %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { _ = tf.Close() })
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
_, err := tf.Read(make([]byte, 64))
|
|
||||||
done <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Verify Read is actually blocked in poll.
|
|
||||||
select {
|
|
||||||
case err := <-done:
|
|
||||||
t.Fatalf("Read returned before shutdown signal: %v", err)
|
|
||||||
case <-time.After(50 * time.Millisecond):
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tf.wakeForShutdown(); err != nil {
|
|
||||||
t.Fatalf("wakeForShutdown: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-done:
|
|
||||||
if !errors.Is(err, os.ErrClosed) {
|
|
||||||
t.Fatalf("expected os.ErrClosed, got %v", err)
|
|
||||||
}
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("Read did not wake on shutdown")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
|
|
||||||
parent, err := newTunFd(newReadPipe(t))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("newTunFd: %v", err)
|
|
||||||
}
|
|
||||||
friend, err := parent.newFriend(newReadPipe(t))
|
|
||||||
if err != nil {
|
|
||||||
_ = parent.Close()
|
|
||||||
t.Fatalf("newFriend: %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = friend.Close()
|
|
||||||
_ = parent.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
readers := []*tunFile{parent, friend}
|
|
||||||
errs := make([]error, len(readers))
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i, r := range readers {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(i int, r *tunFile) {
|
|
||||||
defer wg.Done()
|
|
||||||
_, errs[i] = r.Read(make([]byte, 64))
|
|
||||||
}(i, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
if err := parent.wakeForShutdown(); err != nil {
|
|
||||||
t.Fatalf("wakeForShutdown: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() { wg.Wait(); close(done) }()
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("readers did not wake")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, err := range errs {
|
|
||||||
if !errors.Is(err, os.ErrClosed) {
|
|
||||||
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTunFile_Close_Idempotent(t *testing.T) {
|
|
||||||
tf, err := newTunFd(newReadPipe(t))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("newTunFd: %v", err)
|
|
||||||
}
|
|
||||||
if err := tf.Close(); err != nil {
|
|
||||||
t.Fatalf("first Close: %v", err)
|
|
||||||
}
|
|
||||||
if err := tf.Close(); err != nil {
|
|
||||||
t.Fatalf("second Close should be a no-op, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -20,7 +19,7 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
@@ -103,6 +102,9 @@ type tun struct {
|
|||||||
readPoll [2]unix.PollFd
|
readPoll [2]unix.PollFd
|
||||||
writePoll [2]unix.PollFd
|
writePoll [2]unix.PollFd
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
// blockOnRead waits until the tun fd is readable or shutdown has been signaled.
|
// blockOnRead waits until the tun fd is readable or shutdown has been signaled.
|
||||||
@@ -157,7 +159,16 @@ func (t *tun) blockOnWrite() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
func (t *tun) Read() ([]tio.Packet, error) {
|
||||||
|
n, err := t.readOne(t.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) readOne(to []byte) (int, error) {
|
||||||
// first 4 bytes is protocol family, in network byte order
|
// first 4 bytes is protocol family, in network byte order
|
||||||
var head [4]byte
|
var head [4]byte
|
||||||
iovecs := [2]syscall.Iovec{
|
iovecs := [2]syscall.Iovec{
|
||||||
@@ -375,6 +386,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
|||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
fd: fd,
|
fd: fd,
|
||||||
|
readBuf: make([]byte, defaultBatchBufSize),
|
||||||
shutdownR: shutdownR,
|
shutdownR: shutdownR,
|
||||||
shutdownW: shutdownW,
|
shutdownW: shutdownW,
|
||||||
readPoll: [2]unix.PollFd{
|
readPoll: [2]unix.PollFd{
|
||||||
@@ -565,8 +577,8 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
return fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
@@ -593,6 +605,10 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) removeRoutes(routes []Route) error {
|
func (t *tun) removeRoutes(routes []Route) error {
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
|
|||||||
@@ -16,16 +16,37 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
rwc io.ReadWriteCloser
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read() ([]tio.Packet, error) {
|
||||||
|
n, err := t.rwc.Read(t.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Write(p []byte) (int, error) {
|
||||||
|
return t.rwc.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Close() error {
|
||||||
|
return t.rwc.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
@@ -35,9 +56,10 @@ func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error)
|
|||||||
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||||
t := &tun{
|
t := &tun{
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
ReadWriteCloser: &tunReadCloser{f: file},
|
rwc: &tunReadCloser{f: file},
|
||||||
l: l,
|
l: l,
|
||||||
|
readBuf: make([]byte, defaultBatchBufSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
@@ -155,6 +177,10 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
return fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,7 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -19,180 +17,15 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
|
|
||||||
// A shared eventfd allows Close to wake all readers blocked in poll.
|
|
||||||
type tunFile struct {
|
|
||||||
fd int
|
|
||||||
shutdownFd int
|
|
||||||
lastOne bool
|
|
||||||
readPoll [2]unix.PollFd
|
|
||||||
writePoll [2]unix.PollFd
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
|
|
||||||
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
|
||||||
if err := unix.SetNonblock(fd, true); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
|
||||||
}
|
|
||||||
return &tunFile{
|
|
||||||
fd: fd,
|
|
||||||
shutdownFd: r.shutdownFd,
|
|
||||||
readPoll: [2]unix.PollFd{
|
|
||||||
{Fd: int32(fd), Events: unix.POLLIN},
|
|
||||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
|
||||||
},
|
|
||||||
writePoll: [2]unix.PollFd{
|
|
||||||
{Fd: int32(fd), Events: unix.POLLOUT},
|
|
||||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunFd(fd int) (*tunFile, error) {
|
|
||||||
if err := unix.SetNonblock(fd, true); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
out := &tunFile{
|
|
||||||
fd: fd,
|
|
||||||
shutdownFd: shutdownFd,
|
|
||||||
lastOne: true,
|
|
||||||
readPoll: [2]unix.PollFd{
|
|
||||||
{Fd: int32(fd), Events: unix.POLLIN},
|
|
||||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
|
||||||
},
|
|
||||||
writePoll: [2]unix.PollFd{
|
|
||||||
{Fd: int32(fd), Events: unix.POLLOUT},
|
|
||||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tunFile) blockOnRead() error {
|
|
||||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
|
||||||
var err error
|
|
||||||
for {
|
|
||||||
_, err = unix.Poll(r.readPoll[:], -1)
|
|
||||||
if err != unix.EINTR {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//always reset these!
|
|
||||||
tunEvents := r.readPoll[0].Revents
|
|
||||||
shutdownEvents := r.readPoll[1].Revents
|
|
||||||
r.readPoll[0].Revents = 0
|
|
||||||
r.readPoll[1].Revents = 0
|
|
||||||
//do the err check before trusting the potentially bogus bits we just got
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
|
||||||
return os.ErrClosed
|
|
||||||
} else if tunEvents&problemFlags != 0 {
|
|
||||||
return os.ErrClosed
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tunFile) blockOnWrite() error {
|
|
||||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
|
||||||
var err error
|
|
||||||
for {
|
|
||||||
_, err = unix.Poll(r.writePoll[:], -1)
|
|
||||||
if err != unix.EINTR {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//always reset these!
|
|
||||||
tunEvents := r.writePoll[0].Revents
|
|
||||||
shutdownEvents := r.writePoll[1].Revents
|
|
||||||
r.writePoll[0].Revents = 0
|
|
||||||
r.writePoll[1].Revents = 0
|
|
||||||
//do the err check before trusting the potentially bogus bits we just got
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
|
||||||
return os.ErrClosed
|
|
||||||
} else if tunEvents&problemFlags != 0 {
|
|
||||||
return os.ErrClosed
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tunFile) Read(buf []byte) (int, error) {
|
|
||||||
for {
|
|
||||||
if n, err := unix.Read(r.fd, buf); err == nil {
|
|
||||||
return n, nil
|
|
||||||
} else if err == unix.EAGAIN {
|
|
||||||
if err = r.blockOnRead(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
} else if err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
} else if err == unix.EBADF {
|
|
||||||
return 0, os.ErrClosed
|
|
||||||
} else {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tunFile) Write(buf []byte) (int, error) {
|
|
||||||
for {
|
|
||||||
if n, err := unix.Write(r.fd, buf); err == nil {
|
|
||||||
return n, nil
|
|
||||||
} else if err == unix.EAGAIN {
|
|
||||||
if err = r.blockOnWrite(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
} else if err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
} else if err == unix.EBADF {
|
|
||||||
return 0, os.ErrClosed
|
|
||||||
} else {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tunFile) wakeForShutdown() error {
|
|
||||||
var buf [8]byte
|
|
||||||
binary.NativeEndian.PutUint64(buf[:], 1)
|
|
||||||
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tunFile) Close() error {
|
|
||||||
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
r.closed = true
|
|
||||||
if r.lastOne {
|
|
||||||
_ = unix.Close(r.shutdownFd)
|
|
||||||
}
|
|
||||||
return unix.Close(r.fd)
|
|
||||||
}
|
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
*tunFile
|
readers tio.QueueSet
|
||||||
readers []*tunFile
|
|
||||||
closeLock sync.Mutex
|
closeLock sync.Mutex
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
@@ -239,7 +72,9 @@ type ifreqQLEN struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
// We don't know what flags the caller opened this fd with and can't turn
|
||||||
|
// on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds.
|
||||||
|
t, err := newTunGeneric(c, l, deviceFd, false, false, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -249,46 +84,65 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
// openTunDev opens /dev/net/tun, creating the device node first if it's
|
||||||
|
// missing (docker containers occasionally omit it).
|
||||||
|
func openTunDev() (int, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
return fd, nil
|
||||||
if os.IsNotExist(err) {
|
|
||||||
err = os.MkdirAll("/dev/net", 0755)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
|
||||||
}
|
|
||||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
if err = os.MkdirAll("/dev/net", 0755); err != nil {
|
||||||
|
return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||||
|
}
|
||||||
|
if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
|
||||||
|
return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||||
|
}
|
||||||
|
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||||
|
}
|
||||||
|
return fd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen
|
||||||
|
// device name on success.
|
||||||
|
func tunSetIff(fd int, name string, flags uint16) (string, error) {
|
||||||
var req ifReq
|
var req ifReq
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
req.Flags = flags
|
||||||
|
copy(req.Name[:], name)
|
||||||
|
if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return strings.Trim(string(req.Name[:]), "\x00"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
|
baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||||
if multiqueue {
|
if multiqueue {
|
||||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
baseFlags |= unix.IFF_MULTI_QUEUE
|
||||||
}
|
}
|
||||||
nameStr := c.GetString("tun.dev", "")
|
nameStr := c.GetString("tun.dev", "")
|
||||||
copy(req.Name[:], nameStr)
|
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
_ = unix.Close(fd)
|
|
||||||
return nil, &NameError{
|
|
||||||
Name: nameStr,
|
|
||||||
Underlying: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, fd, vpnNetworks)
|
// First try to enable IFF_VNET_HDR via TUNSETIFF and negotiate TUN_F_*
|
||||||
|
// offloads via TUNSETOFFLOAD so we can receive TSO/USO superpackets.
|
||||||
|
// We try TSO+USO first, fall back to TSO-only on kernels without USO
|
||||||
|
// (Linux < 6.2), and finally give up on virtio headers entirely and
|
||||||
|
// reopen as a plain TUN if neither offload mask is accepted.
|
||||||
|
fd, err := openTunDev()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := tunSetIff(fd, nameStr, baseFlags)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, &NameError{Name: nameStr, Underlying: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := newTunGeneric(c, l, fd, false, false, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -299,15 +153,21 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||||
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vnetHdr, usoEnabled bool, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
tfd, err := newTunFd(fd)
|
qs, err := tio.NewPollQueueSet()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
err = qs.Add(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
tunFile: tfd,
|
readers: qs,
|
||||||
readers: []*tunFile{tfd},
|
|
||||||
closeLock: sync.Mutex{},
|
closeLock: sync.Mutex{},
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
@@ -410,32 +270,29 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() error {
|
||||||
t.closeLock.Lock()
|
t.closeLock.Lock()
|
||||||
defer t.closeLock.Unlock()
|
defer t.closeLock.Unlock()
|
||||||
|
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var req ifReq
|
flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
|
||||||
copy(req.Name[:], t.Device)
|
if _, err = tunSetIff(fd, t.Device, flags); err != nil {
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := t.tunFile.newFriend(fd)
|
err = t.readers.Add(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.readers = append(t.readers, out)
|
return nil
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
@@ -603,6 +460,15 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
|||||||
Table: unix.RT_TABLE_MAIN,
|
Table: unix.RT_TABLE_MAIN,
|
||||||
Type: unix.RTN_UNICAST,
|
Type: unix.RTN_UNICAST,
|
||||||
}
|
}
|
||||||
|
// Match the metric the kernel uses for its auto-installed connected
|
||||||
|
// route, so RouteReplace overwrites it in place instead of adding a
|
||||||
|
// second route at a worse metric. IPv6 connected routes are installed
|
||||||
|
// at metric 256 (IP6_RT_PRIO_KERN); IPv4 uses 0. Without this, the
|
||||||
|
// kernel route wins lookups and our MTU / AdvMSS / Features never
|
||||||
|
// apply on v6.
|
||||||
|
if cidr.Addr().Is6() {
|
||||||
|
nr.Priority = 256
|
||||||
|
}
|
||||||
err := netlink.RouteReplace(&nr)
|
err := netlink.RouteReplace(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
|
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
|
||||||
@@ -869,6 +735,10 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
t.routeTree.Store(newTree)
|
t.routeTree.Store(newTree)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Readers() []tio.Queue {
|
||||||
|
return t.readers.Queues()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
t.closeLock.Lock()
|
t.closeLock.Lock()
|
||||||
defer t.closeLock.Unlock()
|
defer t.closeLock.Unlock()
|
||||||
@@ -878,32 +748,10 @@ func (t *tun) Close() error {
|
|||||||
t.routeChan = nil
|
t.routeChan = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signal all readers blocked in poll to wake up and exit
|
|
||||||
_ = t.tunFile.wakeForShutdown()
|
|
||||||
|
|
||||||
if t.ioctlFd > 0 {
|
if t.ioctlFd > 0 {
|
||||||
_ = unix.Close(int(t.ioctlFd))
|
_ = unix.Close(int(t.ioctlFd))
|
||||||
t.ioctlFd = 0
|
t.ioctlFd = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range t.readers {
|
return t.readers.Close()
|
||||||
if i == 0 {
|
|
||||||
continue //we want to close the zeroth reader last
|
|
||||||
}
|
|
||||||
err := t.readers[i].Close()
|
|
||||||
if err != nil {
|
|
||||||
t.l.Error("error closing tun reader", "reader", i, "error", err)
|
|
||||||
} else {
|
|
||||||
t.l.Info("closed tun reader", "reader", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//this is t.readers[0] too
|
|
||||||
err := t.tunFile.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.l.Error("error closing tun reader", "reader", 0, "error", err)
|
|
||||||
} else {
|
|
||||||
t.l.Info("closed tun reader", "reader", 0)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,9 @@
|
|||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
var runAdvMSSTests = []struct {
|
var runAdvMSSTests = []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -17,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
@@ -66,6 +66,22 @@ type tun struct {
|
|||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
f *os.File
|
f *os.File
|
||||||
fd int
|
fd int
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read() ([]tio.Packet, error) {
|
||||||
|
n, err := t.readOne(t.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
@@ -102,6 +118,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
|||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
|
readBuf: make([]byte, defaultBatchBufSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -141,7 +158,7 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
func (t *tun) readOne(to []byte) (int, error) {
|
||||||
rc, err := t.f.SyscallConn()
|
rc, err := t.f.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
||||||
@@ -394,8 +411,8 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
return fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -17,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
@@ -59,6 +59,18 @@ type tun struct {
|
|||||||
fd int
|
fd int
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
out []byte
|
out []byte
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read() ([]tio.Packet, error) {
|
||||||
|
n, err := t.readOne(t.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
@@ -95,6 +107,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
|||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
|
readBuf: make([]byte, defaultBatchBufSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -124,7 +137,7 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
func (t *tun) readOne(to []byte) (int, error) {
|
||||||
buf := make([]byte, len(to)+4)
|
buf := make([]byte, len(to)+4)
|
||||||
|
|
||||||
n, err := t.f.Read(buf)
|
n, err := t.f.Read(buf)
|
||||||
@@ -314,8 +327,8 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
return fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
@@ -366,6 +379,10 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
|
}
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
@@ -28,6 +29,8 @@ type TestTun struct {
|
|||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
rxPackets chan []byte // Packets to receive into nebula
|
rxPackets chan []byte // Packets to receive into nebula
|
||||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||||
|
|
||||||
|
batchRet [1]tio.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||||
@@ -48,6 +51,9 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*T
|
|||||||
l: l,
|
l: l,
|
||||||
rxPackets: make(chan []byte, 10),
|
rxPackets: make(chan []byte, 10),
|
||||||
TxPackets: make(chan []byte, 10),
|
TxPackets: make(chan []byte, 10),
|
||||||
|
batchRet: [1]tio.Packet{
|
||||||
|
tio.Packet{Bytes: make([]byte, udp.MTU)},
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,7 +168,17 @@ func (t *TestTun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) Read(b []byte) (int, error) {
|
func (t *TestTun) Read() ([]tio.Packet, error) {
|
||||||
|
t.batchRet[0].Bytes = t.batchRet[0].Bytes[:udp.MTU]
|
||||||
|
n, err := t.read(t.batchRet[0].Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0].Bytes = t.batchRet[0].Bytes[:n]
|
||||||
|
return t.batchRet[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TestTun) read(b []byte) (int, error) {
|
||||||
p, ok := <-t.rxPackets
|
p, ok := <-t.rxPackets
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
@@ -177,10 +193,14 @@ func (t *TestTun) Read(b []byte) (int, error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TestTun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TestTun) SupportsMultiqueue() bool {
|
func (t *TestTun) SupportsMultiqueue() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *TestTun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
return fmt.Errorf("TODO: multiqueue not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -18,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/slackhq/nebula/wintun"
|
"github.com/slackhq/nebula/wintun"
|
||||||
@@ -45,6 +45,18 @@ type winTun struct {
|
|||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
tun *wintun.NativeTun
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Read() ([]tio.Packet, error) {
|
||||||
|
n, err := t.tun.Read(t.readBuf, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
|
||||||
|
return t.batchRet[:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||||
@@ -69,6 +81,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
|
|||||||
}
|
}
|
||||||
|
|
||||||
t := &winTun{
|
t := &winTun{
|
||||||
|
readBuf: make([]byte, defaultBatchBufSize),
|
||||||
Device: deviceName,
|
Device: deviceName,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
@@ -255,10 +268,6 @@ func (t *winTun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) Read(b []byte) (int, error) {
|
|
||||||
return t.tun.Read(b, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Write(b []byte) (int, error) {
|
func (t *winTun) Write(b []byte) (int, error) {
|
||||||
return t.tun.Write(b, 0)
|
return t.tun.Write(b, 0)
|
||||||
}
|
}
|
||||||
@@ -267,8 +276,12 @@ func (t *winTun) SupportsMultiqueue() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *winTun) NewMultiQueueReader() error {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
return fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{t}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) Close() error {
|
func (t *winTun) Close() error {
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay/tio"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,12 +30,28 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
|
|||||||
|
|
||||||
type UserDevice struct {
|
type UserDevice struct {
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
|
numReaders int
|
||||||
|
|
||||||
outboundReader *io.PipeReader
|
outboundReader *io.PipeReader
|
||||||
outboundWriter *io.PipeWriter
|
outboundWriter *io.PipeWriter
|
||||||
|
|
||||||
inboundReader *io.PipeReader
|
inboundReader *io.PipeReader
|
||||||
inboundWriter *io.PipeWriter
|
inboundWriter *io.PipeWriter
|
||||||
|
|
||||||
|
readBuf []byte
|
||||||
|
batchRet [1]tio.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UserDevice) Read() ([]tio.Packet, error) {
|
||||||
|
if d.readBuf == nil {
|
||||||
|
d.readBuf = make([]byte, defaultBatchBufSize)
|
||||||
|
}
|
||||||
|
n, err := d.outboundReader.Read(d.readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d.batchRet[0] = tio.Packet{Bytes: d.readBuf[:n]}
|
||||||
|
return d.batchRet[:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) Activate() error {
|
func (d *UserDevice) Activate() error {
|
||||||
@@ -47,23 +65,25 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) SupportsMultiqueue() bool {
|
func (d *UserDevice) SupportsMultiqueue() bool {
|
||||||
return true
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() error {
|
||||||
return d, nil
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UserDevice) Readers() []tio.Queue {
|
||||||
|
return []tio.Queue{d}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
||||||
return d.inboundReader, d.outboundWriter
|
return d.inboundReader, d.outboundWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) Read(p []byte) (n int, err error) {
|
|
||||||
return d.outboundReader.Read(p)
|
|
||||||
}
|
|
||||||
func (d *UserDevice) Write(p []byte) (n int, err error) {
|
func (d *UserDevice) Write(p []byte) (n int, err error) {
|
||||||
return d.inboundWriter.Write(p)
|
return d.inboundWriter.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) Close() error {
|
func (d *UserDevice) Close() error {
|
||||||
d.inboundWriter.Close()
|
d.inboundWriter.Close()
|
||||||
d.outboundWriter.Close()
|
d.outboundWriter.Close()
|
||||||
|
|||||||
40
udp/conn.go
40
udp/conn.go
@@ -8,16 +8,49 @@ import (
|
|||||||
|
|
||||||
const MTU = 9001
|
const MTU = 9001
|
||||||
|
|
||||||
|
// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is
|
||||||
|
// required to accept. Callers SHOULD NOT pass more than this per call; Linux
|
||||||
|
// backends preallocate sendmmsg scratch sized to this value, so exceeding it
|
||||||
|
// only costs additional sendmmsg chunks within a single WriteBatch call.
|
||||||
|
const MaxWriteBatch = 128
|
||||||
|
|
||||||
|
// RxMeta carries per-packet metadata extracted from the RX path (ancillary
|
||||||
|
// data, kernel offload state, etc.) and passed to EncReader callbacks.
|
||||||
|
// Backends that do not produce a particular signal leave its zero value.
|
||||||
|
//
|
||||||
|
// OuterECN is the 2-bit IP-level ECN codepoint stamped on the carrier
|
||||||
|
// datagram (extracted from IP_TOS / IPV6_TCLASS cmsg on Linux). Zero
|
||||||
|
// means Not-ECT, which is also the value backends without ECN RX support
|
||||||
|
// supply on every packet.
|
||||||
|
type RxMeta struct {
|
||||||
|
OuterECN byte
|
||||||
|
}
|
||||||
|
|
||||||
type EncReader func(
|
type EncReader func(
|
||||||
addr netip.AddrPort,
|
addr netip.AddrPort,
|
||||||
payload []byte,
|
payload []byte,
|
||||||
|
meta RxMeta,
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader) error
|
// ListenOut invokes r for each received packet. On batch-capable
|
||||||
|
// backends (recvmmsg), flush is called after each batch is fully
|
||||||
|
// delivered — callers use it to flush per-batch accumulators such as
|
||||||
|
// TUN write coalescers. Single-packet backends call flush after each
|
||||||
|
// packet. flush must not be nil.
|
||||||
|
ListenOut(r EncReader, flush func()) error
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
|
// WriteBatch sends a contiguous batch of packets, each with its own
|
||||||
|
// destination. bufs and addrs must have the same length. outerECNs may
|
||||||
|
// be nil (treated as all-zero / Not-ECT); when non-nil it must have the
|
||||||
|
// same length as bufs, and outerECNs[i] is the 2-bit IP-level ECN
|
||||||
|
// codepoint to set on packet i's outer header. Linux uses sendmmsg(2)
|
||||||
|
// for a single syscall and attaches the value as IP_TOS / IPV6_TCLASS
|
||||||
|
// cmsg; other backends ignore it. Returns on the first error; callers
|
||||||
|
// may observe a partial send if some packets went out before the error.
|
||||||
|
WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
SupportsMultipleReaders() bool
|
SupportsMultipleReaders() bool
|
||||||
Close() error
|
Close() error
|
||||||
@@ -31,7 +64,7 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader) error {
|
func (NoopConn) ListenOut(_ EncReader, _ func()) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (NoopConn) SupportsMultipleReaders() bool {
|
func (NoopConn) SupportsMultipleReaders() bool {
|
||||||
@@ -40,6 +73,9 @@ func (NoopConn) SupportsMultipleReaders() bool {
|
|||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort, _ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
|
||||||
|
for i, b := range bufs {
|
||||||
|
if err := u.WriteTo(b, addrs[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
a := u.UDPConn.LocalAddr()
|
a := u.UDPConn.LocalAddr()
|
||||||
|
|
||||||
@@ -165,7 +174,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
|||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -179,7 +188,8 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
|||||||
u.l.Error("unexpected udp socket receive error", "error", err)
|
u.l.Error("unexpected udp socket receive error", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{})
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
|
||||||
|
for i, b := range bufs {
|
||||||
|
if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
a := u.UDPConn.LocalAddr()
|
a := u.UDPConn.LocalAddr()
|
||||||
|
|
||||||
@@ -73,7 +82,7 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
func (u *GenericConn) ListenOut(r EncReader, flush func()) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
var lastRecvErr time.Time
|
var lastRecvErr time.Time
|
||||||
@@ -93,7 +102,8 @@ func (u *GenericConn) ListenOut(r EncReader) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{})
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
185
udp/udp_linux.go
185
udp/udp_linux.go
@@ -24,6 +24,22 @@ type StdConn struct {
|
|||||||
isV4 bool
|
isV4 bool
|
||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
batch int
|
batch int
|
||||||
|
|
||||||
|
// sendmmsg scratch. Each queue has its own StdConn, so no locking is
|
||||||
|
// needed. Sized to MaxWriteBatch at construction; WriteBatch chunks
|
||||||
|
// larger inputs.
|
||||||
|
writeMsgs []rawMessage
|
||||||
|
writeIovs []iovec
|
||||||
|
writeNames [][]byte
|
||||||
|
|
||||||
|
// sendmmsg(2) callback state. sendmmsgCB is bound once in NewListener
|
||||||
|
// to the sendmmsgRun method value so passing it to rawConn.Write does
|
||||||
|
// not allocate a fresh closure per send; sendmmsgN/Sent/Errno carry
|
||||||
|
// the inputs and outputs across the call without escaping locals.
|
||||||
|
sendmmsgCB func(fd uintptr) bool
|
||||||
|
sendmmsgN int
|
||||||
|
sendmmsgSent int
|
||||||
|
sendmmsgErrno syscall.Errno
|
||||||
}
|
}
|
||||||
|
|
||||||
func setReusePort(network, address string, c syscall.RawConn) error {
|
func setReusePort(network, address string, c syscall.RawConn) error {
|
||||||
@@ -70,9 +86,23 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int)
|
|||||||
}
|
}
|
||||||
out.isV4 = af == unix.AF_INET
|
out.isV4 = af == unix.AF_INET
|
||||||
|
|
||||||
|
out.prepareWriteMessages(MaxWriteBatch)
|
||||||
|
out.sendmmsgCB = out.sendmmsgRun
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) prepareWriteMessages(n int) {
|
||||||
|
u.writeMsgs = make([]rawMessage, n)
|
||||||
|
u.writeIovs = make([]iovec, n)
|
||||||
|
u.writeNames = make([][]byte, n)
|
||||||
|
|
||||||
|
for i := range u.writeMsgs {
|
||||||
|
u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) SupportsMultipleReaders() bool {
|
func (u *StdConn) SupportsMultipleReaders() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -171,7 +201,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
|
|||||||
return int(n), true, nil
|
return int(n), true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) listenOutSingle(r EncReader) error {
|
func (u *StdConn) listenOutSingle(r EncReader, flush func()) error {
|
||||||
var err error
|
var err error
|
||||||
var n int
|
var n int
|
||||||
var from netip.AddrPort
|
var from netip.AddrPort
|
||||||
@@ -183,16 +213,33 @@ func (u *StdConn) listenOutSingle(r EncReader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
|
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
|
||||||
r(from, buffer[:n])
|
// listenOutSingle uses ReadFromUDPAddrPort which discards cmsgs,
|
||||||
|
// so the outer ECN field is not visible on this path. Zero RxMeta
|
||||||
|
// (Not-ECT) means RFC 6040 combine is a no-op.
|
||||||
|
r(from, buffer[:n], RxMeta{})
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) listenOutBatch(r EncReader) error {
|
// readSockaddr decodes the source address out of a recvmmsg name buffer
|
||||||
|
func (u *StdConn) readSockaddr(name []byte) netip.AddrPort {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
|
// It's ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
|
if u.isV4 {
|
||||||
|
ip, _ = netip.AddrFromSlice(name[4:8])
|
||||||
|
} else {
|
||||||
|
ip, _ = netip.AddrFromSlice(name[8:24])
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(name[2:4]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
|
||||||
var n int
|
var n int
|
||||||
var operr error
|
var operr error
|
||||||
|
|
||||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
bufSize := MTU
|
||||||
|
cmsgSpace := 0
|
||||||
|
msgs, buffers, names, _ := u.PrepareRawMessages(u.batch, bufSize, cmsgSpace)
|
||||||
|
|
||||||
//reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read
|
//reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read
|
||||||
//defining it outside the loop so it gets re-used
|
//defining it outside the loop so it gets re-used
|
||||||
@@ -211,22 +258,18 @@ func (u *StdConn) listenOutBatch(r EncReader) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
r(u.readSockaddr(names[i]), buffers[i][:msgs[i].Len], RxMeta{})
|
||||||
if u.isV4 {
|
|
||||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
|
||||||
} else {
|
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
|
||||||
}
|
|
||||||
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
return u.listenOutSingle(r)
|
return u.listenOutSingle(r, flush)
|
||||||
} else {
|
} else {
|
||||||
return u.listenOutBatch(r)
|
return u.listenOutBatch(r, flush)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,6 +278,120 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on
|
||||||
|
// StdConn. If supported, consecutive packets to the same destination with
|
||||||
|
// matching segment sizes (all but possibly the last) are coalesced into a
|
||||||
|
// single mmsghdr entry
|
||||||
|
//
|
||||||
|
// If sendmmsg returns an error and zero entries went out, we fall back to
|
||||||
|
// per-packet WriteTo for that chunk so the caller still gets best-effort
|
||||||
|
// delivery. On a partial send we resume at the first un-acked entry on
|
||||||
|
// the next iteration.
|
||||||
|
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
|
||||||
|
for i := 0; i < len(bufs); {
|
||||||
|
chunk := min(len(bufs)-i, len(u.writeMsgs))
|
||||||
|
|
||||||
|
for k := 0; k < chunk; k++ {
|
||||||
|
u.writeIovs[k].Base = &bufs[i+k][0]
|
||||||
|
setIovLen(&u.writeIovs[k], len(bufs[i+k]))
|
||||||
|
|
||||||
|
nlen, err := writeSockaddr(u.writeNames[k], addrs[i+k], u.isV4)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
hdr := &u.writeMsgs[k].Hdr
|
||||||
|
hdr.Iov = &u.writeIovs[k]
|
||||||
|
setMsgIovlen(hdr, 1)
|
||||||
|
hdr.Namelen = uint32(nlen)
|
||||||
|
}
|
||||||
|
|
||||||
|
sent, serr := u.sendmmsg(chunk)
|
||||||
|
if serr != nil && sent <= 0 {
|
||||||
|
// sendmmsg returns -1 / sent=0 when entry 0 itself failed; log
|
||||||
|
// that entry's destination and fall back to per-packet WriteTo
|
||||||
|
// for the whole chunk so the caller still gets best-effort
|
||||||
|
// delivery without duplicating packets the kernel accepted.
|
||||||
|
u.l.Warn("sendmmsg failed, falling back to per-packet WriteTo",
|
||||||
|
"err", serr,
|
||||||
|
"entries", chunk,
|
||||||
|
"entry0_dst", addrs[i],
|
||||||
|
"isV4", u.isV4,
|
||||||
|
)
|
||||||
|
for k := 0; k < chunk; k++ {
|
||||||
|
if werr := u.WriteTo(bufs[i+k], addrs[i+k]); werr != nil {
|
||||||
|
return werr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += chunk
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i += sent
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendmmsg issues sendmmsg(2) against the first n entries of u.writeMsgs.
|
||||||
|
// The bound u.sendmmsgCB is passed to rawConn.Write so no closure is
|
||||||
|
// allocated per call; inputs and outputs ride on the StdConn fields.
|
||||||
|
func (u *StdConn) sendmmsg(n int) (int, error) {
|
||||||
|
u.sendmmsgN = n
|
||||||
|
u.sendmmsgSent = 0
|
||||||
|
u.sendmmsgErrno = 0
|
||||||
|
if err := u.rawConn.Write(u.sendmmsgCB); err != nil {
|
||||||
|
return u.sendmmsgSent, err
|
||||||
|
}
|
||||||
|
if u.sendmmsgErrno != 0 {
|
||||||
|
return u.sendmmsgSent, &net.OpError{Op: "sendmmsg", Err: u.sendmmsgErrno}
|
||||||
|
}
|
||||||
|
return u.sendmmsgSent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendmmsgRun is the rawConn.Write callback. It is bound once into
|
||||||
|
// u.sendmmsgCB at construction so it stays alloc-free in the hot path;
|
||||||
|
// inputs (sendmmsgN) and outputs (sendmmsgSent, sendmmsgErrno) ride on
|
||||||
|
// the receiver rather than escaping locals.
|
||||||
|
func (u *StdConn) sendmmsgRun(fd uintptr) bool {
|
||||||
|
r1, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, fd,
|
||||||
|
uintptr(unsafe.Pointer(&u.writeMsgs[0])), uintptr(u.sendmmsgN),
|
||||||
|
0, 0, 0,
|
||||||
|
)
|
||||||
|
if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
u.sendmmsgSent = int(r1)
|
||||||
|
u.sendmmsgErrno = errno
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeSockaddr encodes addr into buf (which must be at least
|
||||||
|
// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is
|
||||||
|
// true and addr is not a v4 (or v4-in-v6) address, returns an error.
|
||||||
|
func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) {
|
||||||
|
ap := addr.Addr().Unmap()
|
||||||
|
if isV4 {
|
||||||
|
if !ap.Is4() {
|
||||||
|
return 0, ErrInvalidIPv6RemoteForSocket
|
||||||
|
}
|
||||||
|
// struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) }
|
||||||
|
// sa_family is host endian.
|
||||||
|
binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET)
|
||||||
|
binary.BigEndian.PutUint16(buf[2:4], addr.Port())
|
||||||
|
ip4 := ap.As4()
|
||||||
|
copy(buf[4:8], ip4[:])
|
||||||
|
clear(buf[8:16])
|
||||||
|
return unix.SizeofSockaddrInet4, nil
|
||||||
|
}
|
||||||
|
// struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) }
|
||||||
|
binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6)
|
||||||
|
binary.BigEndian.PutUint16(buf[2:4], addr.Port())
|
||||||
|
binary.NativeEndian.PutUint32(buf[4:8], 0)
|
||||||
|
ip6 := addr.Addr().As16()
|
||||||
|
copy(buf[8:24], ip6[:])
|
||||||
|
binary.NativeEndian.PutUint32(buf[24:28], 0)
|
||||||
|
return unix.SizeofSockaddrInet6, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||||
b := c.GetInt("listen.read_buffer", 0)
|
b := c.GetInt("listen.read_buffer", 0)
|
||||||
if b > 0 {
|
if b > 0 {
|
||||||
|
|||||||
@@ -30,13 +30,18 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
|
var cmsgs []byte
|
||||||
|
if cmsgSpace > 0 {
|
||||||
|
cmsgs = make([]byte, n*cmsgSpace)
|
||||||
|
}
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, MTU)
|
buffers[i] = make([]byte, bufSize)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -48,7 +53,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
|
if cmsgSpace > 0 {
|
||||||
|
msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
|
||||||
|
msgs[i].Hdr.Controllen = uint32(cmsgSpace)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names, cmsgs
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIovLen(v *iovec, n int) {
|
||||||
|
v.Len = uint32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMsgIovlen(m *msghdr, n int) {
|
||||||
|
m.Iovlen = uint32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMsgControllen(m *msghdr, n int) {
|
||||||
|
m.Controllen = uint32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||||
|
h.Len = uint32(n)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,13 +33,18 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
|
var cmsgs []byte
|
||||||
|
if cmsgSpace > 0 {
|
||||||
|
cmsgs = make([]byte, n*cmsgSpace)
|
||||||
|
}
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, MTU)
|
buffers[i] = make([]byte, bufSize)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -51,7 +56,28 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
|
if cmsgSpace > 0 {
|
||||||
|
msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
|
||||||
|
msgs[i].Hdr.Controllen = uint64(cmsgSpace)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names, cmsgs
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIovLen(v *iovec, n int) {
|
||||||
|
v.Len = uint64(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMsgIovlen(m *msghdr, n int) {
|
||||||
|
m.Iovlen = uint64(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMsgControllen(m *msghdr, n int) {
|
||||||
|
m.Controllen = uint64(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||||
|
h.Len = uint64(n)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
func (u *RIOConn) ListenOut(r EncReader, flush func()) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
var lastRecvErr time.Time
|
var lastRecvErr time.Time
|
||||||
@@ -161,7 +161,8 @@ func (u *RIOConn) ListenOut(r EncReader) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], RxMeta{})
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -316,6 +317,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
|
|||||||
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
|
||||||
|
for i, b := range bufs {
|
||||||
|
if err := u.WriteTo(b, addrs[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
sa, err := windows.Getsockname(u.sock)
|
sa, err := windows.Getsockname(u.sock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -157,15 +157,24 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
|
||||||
|
for i, b := range bufs {
|
||||||
|
if err := u.WriteTo(b, addrs[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *TesterConn) ListenOut(r EncReader) error {
|
func (u *TesterConn) ListenOut(r EncReader, flush func()) error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-u.done:
|
case <-u.done:
|
||||||
return os.ErrClosed
|
return os.ErrClosed
|
||||||
case p := <-u.RxPackets:
|
case p := <-u.RxPackets:
|
||||||
r(p.From, p.Data)
|
r(p.From, p.Data, RxMeta{})
|
||||||
p.Release()
|
p.Release()
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user