mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
better and batched tun interface
This commit is contained in:
75
inside.go
75
inside.go
@@ -9,10 +9,11 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/overlay/batch"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
@@ -37,7 +38,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||
// TUN device.
|
||||
if immediatelyForwardToSelf {
|
||||
_, err := f.readers[q].Write(packet)
|
||||
_, err := f.readers[q].WriteFromSelf(packet)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to forward to tun", "error", err)
|
||||
}
|
||||
@@ -57,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
})
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
f.rejectInside(packet, rejectBuf, q)
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
|
||||
"vpnAddr", fwPacket.RemoteAddr,
|
||||
@@ -73,10 +74,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
|
||||
f.sendInsideMessage(hostinfo, packet, nb, sendBatch, rejectBuf, q)
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
f.rejectInside(packet, rejectBuf, q)
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("dropping outbound packet",
|
||||
"fwPacket", fwPacket,
|
||||
@@ -86,6 +86,67 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
}
|
||||
}
|
||||
|
||||
// sendInsideMessage encrypts a firewall-approved inside packet into the
|
||||
// caller's batch slot for later sendmmsg flush. When hostinfo.remote is not
|
||||
// valid we fall through to the relay slow path via the unbatched sendNoMetrics
|
||||
// so relay behavior is unchanged.
|
||||
func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) {
|
||||
ci := hostinfo.ConnectionState
|
||||
if ci.eKey == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !hostinfo.remote.IsValid() {
|
||||
// Slow path: relay fallback. Reuse rejectBuf as the ciphertext
|
||||
// scratch; sendNoMetrics arranges header space for SendVia.
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q)
|
||||
return
|
||||
}
|
||||
|
||||
scratch := sendBatch.Next()
|
||||
if scratch == nil {
|
||||
// Batch full: bypass batching and send this packet directly so we
|
||||
// never drop traffic on over-subscribed iterations.
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q)
|
||||
return
|
||||
}
|
||||
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Lock()
|
||||
}
|
||||
c := ci.messageCounter.Add(1)
|
||||
|
||||
out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
|
||||
f.connectionManager.Out(hostinfo)
|
||||
|
||||
if hostinfo.lastRebindCount != f.rebindCount {
|
||||
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("Lighthouse update triggered for punch due to rebind counter",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
out, err := ci.eKey.EncryptDanger(out, out, p, c, nb)
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
|
||||
"error", err,
|
||||
"udpAddr", hostinfo.remote,
|
||||
"counter", c,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
sendBatch.Commit(len(out), hostinfo.remote)
|
||||
}
|
||||
|
||||
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||
if !f.firewall.InSendReject {
|
||||
return
|
||||
@@ -96,7 +157,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err := f.readers[q].Write(out)
|
||||
_, err := f.readers[q].WriteFromSelf(out)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to write to tun", "error", err)
|
||||
}
|
||||
|
||||
64
interface.go
64
interface.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -18,6 +17,8 @@ import (
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/overlay/batch"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
@@ -88,7 +89,11 @@ type Interface struct {
|
||||
|
||||
ctx context.Context
|
||||
writers []udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
readers []tio.Queue
|
||||
// batchers is one per tun queue, wrapping readers[i].
|
||||
// decryptToTun sends plaintext into the batch.RxBatcher;
|
||||
// listenOut calls its Flush at the end of each UDP recvmmsg batch.
|
||||
batchers []batch.RxBatcher
|
||||
wg sync.WaitGroup
|
||||
|
||||
// fatalErr holds the first unexpected reader error that caused shutdown.
|
||||
@@ -187,7 +192,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
routines: c.routines,
|
||||
version: c.version,
|
||||
writers: make([]udp.Conn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
readers: make([]tio.Queue, c.routines),
|
||||
batchers: make([]batch.RxBatcher, c.routines),
|
||||
myVpnNetworks: cs.myVpnNetworks,
|
||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||
myVpnAddrs: cs.myVpnAddrs,
|
||||
@@ -245,15 +251,16 @@ func (f *Interface) activate() error {
|
||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||
|
||||
// Prepare n tun queues
|
||||
var reader io.ReadWriteCloser = f.inside
|
||||
for i := 0; i < f.routines; i++ {
|
||||
if i > 0 {
|
||||
reader, err = f.inside.NewMultiQueueReader()
|
||||
if err != nil {
|
||||
if err = f.inside.NewMultiQueueReader(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
f.readers[i] = reader
|
||||
}
|
||||
f.readers = f.inside.Readers()
|
||||
for i := range f.readers {
|
||||
f.batchers[i] = batch.NewPassthrough(f.readers[i])
|
||||
}
|
||||
|
||||
f.wg.Add(1) // for us to wait on Close() to return
|
||||
@@ -311,14 +318,24 @@ func (f *Interface) listenOut(i int) {
|
||||
|
||||
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
plaintext := make([]byte, udp.MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
coalescer := f.batchers[i]
|
||||
|
||||
listener := func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
plaintext := f.batchers[i].Reserve(len(payload))
|
||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get())
|
||||
})
|
||||
}
|
||||
|
||||
flusher := func() {
|
||||
if err := coalescer.Flush(); err != nil {
|
||||
f.l.Error("Failed to flush tun coalescer", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := li.ListenOut(listener, flusher)
|
||||
|
||||
if err != nil && !f.closed.Load() {
|
||||
f.l.Error("Error while reading inbound packet, closing", "error", err)
|
||||
@@ -328,16 +345,16 @@ func (f *Interface) listenOut(i int) {
|
||||
f.l.Debug("underlay reader is done", "reader", i)
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
func (f *Interface) listenIn(reader tio.Queue, i int) {
|
||||
rejectBuf := make([]byte, mtu)
|
||||
sb := batch.NewSendBatch(batch.SendBatchCap, udp.MTU+32)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := reader.Read(packet)
|
||||
pkts, err := reader.Read()
|
||||
if err != nil {
|
||||
if !f.closed.Load() {
|
||||
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
|
||||
@@ -346,12 +363,29 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
break
|
||||
}
|
||||
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
||||
sb.Reset()
|
||||
for _, pkt := range pkts {
|
||||
if sb.Len() >= sb.Cap() {
|
||||
f.flushBatch(sb, i)
|
||||
sb.Reset()
|
||||
}
|
||||
f.consumeInsidePacket(pkt, fwPacket, nb, sb, rejectBuf, i, conntrackCache.Get())
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
f.flushBatch(sb, i)
|
||||
}
|
||||
}
|
||||
|
||||
f.l.Debug("overlay reader is done", "reader", i)
|
||||
}
|
||||
|
||||
func (f *Interface) flushBatch(sb batch.TxBatcher, q int) {
|
||||
bufs, dsts := sb.Get()
|
||||
if err := f.writers[q].WriteBatch(bufs, dsts); err != nil {
|
||||
f.l.Error("Failed to write outgoing batch", "error", err, "writer", q)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
c.RegisterReloadCallback(f.reloadFirewall)
|
||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||
|
||||
@@ -572,7 +572,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
}
|
||||
|
||||
f.connectionManager.In(hostinfo)
|
||||
_, err = f.readers[q].Write(out)
|
||||
err = f.batchers[q].Commit(out)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to write to tun", "error", err)
|
||||
}
|
||||
|
||||
33
overlay/batch/batch.go
Normal file
33
overlay/batch/batch.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package batch
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type RxBatcher interface {
|
||||
// Reserve creates a pkt to borrow
|
||||
Reserve(sz int) []byte
|
||||
// Commit borrows pkt. The caller must keep pkt valid until the next Flush
|
||||
Commit(pkt []byte) error
|
||||
// Flush emits every queued packet in arrival order. Returns the
|
||||
// first error observed; keeps draining so one bad packet doesn't hold up
|
||||
// the rest. After Flush returns, borrowed payload slices may be recycled.
|
||||
Flush() error
|
||||
}
|
||||
|
||||
type TxBatcher interface {
|
||||
// Next returns a zero-length slice with slotCap capacity over the next unused
|
||||
// slot's backing bytes. The caller writes into the returned slice and then
|
||||
// calls Commit with the final length and destination. Next returns nil when
|
||||
// the batch is full.
|
||||
Next() []byte
|
||||
// Commit records the slot just returned by Next as a packet of length n
|
||||
// destined for dst.
|
||||
Commit(n int, dst netip.AddrPort)
|
||||
// Reset clears committed slots; backing storage is retained for reuse.
|
||||
Reset()
|
||||
// Len returns the number of committed packets.
|
||||
Len() int
|
||||
// Cap returns the maximum number of slots in the batch.
|
||||
Cap() int
|
||||
// Get returns the buffers needed to send the batch
|
||||
Get() ([][]byte, []netip.AddrPort)
|
||||
}
|
||||
57
overlay/batch/passthrough.go
Normal file
57
overlay/batch/passthrough.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package batch
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
// Passthrough is a RxBatcher that doesn't batch anything, it just accumulates and then sends packets.
|
||||
type Passthrough struct {
|
||||
out io.Writer
|
||||
slots [][]byte
|
||||
backing []byte
|
||||
cursor int
|
||||
}
|
||||
|
||||
func NewPassthrough(w io.Writer) *Passthrough {
|
||||
const baseNumSlots = 128
|
||||
return &Passthrough{
|
||||
out: w,
|
||||
slots: make([][]byte, 0, baseNumSlots),
|
||||
backing: make([]byte, 0, baseNumSlots*udp.MTU),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Passthrough) Reserve(sz int) []byte {
|
||||
if len(p.backing)+sz > cap(p.backing) {
|
||||
// Grow: allocate a fresh backing. Already-committed slices still
|
||||
// reference the old array and remain valid until Flush drops them.
|
||||
newCap := max(cap(p.backing)*2, sz)
|
||||
p.backing = make([]byte, 0, newCap)
|
||||
}
|
||||
start := len(p.backing)
|
||||
p.backing = p.backing[:start+sz]
|
||||
return p.backing[start : start+sz : start+sz] //return zero length, sz-cap slice
|
||||
}
|
||||
|
||||
func (p *Passthrough) Commit(pkt []byte) error {
|
||||
p.slots = append(p.slots, pkt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Passthrough) Flush() error {
|
||||
var firstErr error
|
||||
for _, s := range p.slots {
|
||||
_, err := p.out.Write(s)
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
for i := range p.slots {
|
||||
p.slots[i] = nil
|
||||
}
|
||||
p.slots = p.slots[:0]
|
||||
p.backing = p.backing[:0]
|
||||
return firstErr
|
||||
}
|
||||
61
overlay/batch/tx_batch.go
Normal file
61
overlay/batch/tx_batch.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package batch
|
||||
|
||||
import "net/netip"
|
||||
|
||||
const SendBatchCap = 128
|
||||
|
||||
// SendBatch accumulates encrypted UDP packets for potential TX offloading.
|
||||
// One SendBatch is owned by each listenIn goroutine; no locking is needed.
|
||||
// The backing storage holds up to batchCap packets of slotCap bytes each;
|
||||
// bufs and dsts are parallel slices of committed slots.
|
||||
type SendBatch struct {
|
||||
bufs [][]byte
|
||||
dsts []netip.AddrPort
|
||||
backing []byte
|
||||
slotCap int
|
||||
batchCap int
|
||||
nextSlot int
|
||||
}
|
||||
|
||||
func NewSendBatch(batchCap, slotCap int) *SendBatch {
|
||||
return &SendBatch{
|
||||
bufs: make([][]byte, 0, batchCap),
|
||||
dsts: make([]netip.AddrPort, 0, batchCap),
|
||||
backing: make([]byte, batchCap*slotCap),
|
||||
slotCap: slotCap,
|
||||
batchCap: batchCap,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *SendBatch) Next() []byte {
|
||||
if b.nextSlot >= b.batchCap {
|
||||
return nil
|
||||
}
|
||||
start := b.nextSlot * b.slotCap
|
||||
return b.backing[start : start : start+b.slotCap] //set len to 0 but cap to slotCap
|
||||
}
|
||||
|
||||
func (b *SendBatch) Commit(n int, dst netip.AddrPort) {
|
||||
start := b.nextSlot * b.slotCap
|
||||
b.bufs = append(b.bufs, b.backing[start:start+n])
|
||||
b.dsts = append(b.dsts, dst)
|
||||
b.nextSlot++
|
||||
}
|
||||
|
||||
func (b *SendBatch) Reset() {
|
||||
b.bufs = b.bufs[:0]
|
||||
b.dsts = b.dsts[:0]
|
||||
b.nextSlot = 0
|
||||
}
|
||||
|
||||
func (b *SendBatch) Len() int {
|
||||
return len(b.bufs)
|
||||
}
|
||||
|
||||
func (b *SendBatch) Cap() int {
|
||||
return b.batchCap
|
||||
}
|
||||
|
||||
func (b *SendBatch) Get() ([][]byte, []netip.AddrPort) {
|
||||
return b.bufs, b.dsts
|
||||
}
|
||||
69
overlay/batch/tx_batch_test.go
Normal file
69
overlay/batch/tx_batch_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package batch
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSendBatchBookkeeping(t *testing.T) {
|
||||
b := NewSendBatch(4, 32)
|
||||
if b.Len() != 0 || b.Cap() != 4 {
|
||||
t.Fatalf("fresh batch: len=%d cap=%d", b.Len(), b.Cap())
|
||||
}
|
||||
|
||||
ap := netip.MustParseAddrPort("10.0.0.1:4242")
|
||||
for i := 0; i < 4; i++ {
|
||||
slot := b.Next()
|
||||
if slot == nil {
|
||||
t.Fatalf("slot %d: Next returned nil before cap", i)
|
||||
}
|
||||
if cap(slot) != 32 || len(slot) != 0 {
|
||||
t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot))
|
||||
}
|
||||
// Write a marker byte.
|
||||
slot = append(slot, byte(i), byte(i+1), byte(i+2))
|
||||
b.Commit(len(slot), ap)
|
||||
}
|
||||
if b.Next() != nil {
|
||||
t.Fatalf("Next should return nil when full")
|
||||
}
|
||||
if b.Len() != 4 {
|
||||
t.Fatalf("Len=%d want 4", b.Len())
|
||||
}
|
||||
for i, buf := range b.bufs {
|
||||
if len(buf) != 3 || buf[0] != byte(i) {
|
||||
t.Errorf("buf %d: %x", i, buf)
|
||||
}
|
||||
if b.dsts[i] != ap {
|
||||
t.Errorf("dst %d: got %v want %v", i, b.dsts[i], ap)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset returns empty and Next works again.
|
||||
b.Reset()
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("after Reset Len=%d want 0", b.Len())
|
||||
}
|
||||
slot := b.Next()
|
||||
if slot == nil || cap(slot) != 32 {
|
||||
t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", slot == nil, cap(slot))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendBatchSlotsDoNotOverlap(t *testing.T) {
|
||||
b := NewSendBatch(3, 8)
|
||||
ap := netip.MustParseAddrPort("10.0.0.1:80")
|
||||
|
||||
// Fill three slots, each with its own sentinel byte.
|
||||
for i := 0; i < 3; i++ {
|
||||
s := b.Next()
|
||||
s = append(s, byte(0xA0+i), byte(0xB0+i))
|
||||
b.Commit(len(s), ap)
|
||||
}
|
||||
|
||||
for i, buf := range b.bufs {
|
||||
if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) {
|
||||
t.Errorf("slot %d corrupted: %x", i, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,15 +4,21 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
|
||||
// that don't do TSO segmentation. 65535 covers any single IP packet.
|
||||
const defaultBatchBufSize = 65535
|
||||
|
||||
type Device interface {
|
||||
io.ReadWriteCloser
|
||||
io.Closer
|
||||
Activate() error
|
||||
Networks() []netip.Prefix
|
||||
Name() string
|
||||
RoutesFor(netip.Addr) routing.Gateways
|
||||
SupportsMultiqueue() bool
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
SupportsMultiqueue() bool //todo remove?
|
||||
NewMultiQueueReader() error
|
||||
Readers() []tio.Queue
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ package overlaytest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
@@ -31,20 +31,28 @@ func (NoopTun) Name() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (NoopTun) Read([]byte) (int, error) {
|
||||
return 0, nil
|
||||
func (NoopTun) Read() ([][]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (NoopTun) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (NoopTun) WriteFromSelf(p []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (NoopTun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, errors.New("unsupported")
|
||||
func (NoopTun) NewMultiQueueReader() error {
|
||||
return errors.New("unsupported")
|
||||
}
|
||||
|
||||
func (NoopTun) Readers() []tio.Queue {
|
||||
return []tio.Queue{NoopTun{}}
|
||||
}
|
||||
|
||||
func (NoopTun) Close() error {
|
||||
|
||||
69
overlay/tio/container_poll_linux.go
Normal file
69
overlay/tio/container_poll_linux.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package tio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type pollContainer struct {
|
||||
pq []*Poll
|
||||
// pqi is exactly the same as pq, but stored as the interface type
|
||||
pqi []Queue
|
||||
shutdownFd int
|
||||
}
|
||||
|
||||
func NewPollContainer() (Container, error) {
|
||||
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||
}
|
||||
|
||||
out := &pollContainer{
|
||||
pq: []*Poll{},
|
||||
pqi: []Queue{},
|
||||
shutdownFd: shutdownFd,
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *pollContainer) Queues() []Queue {
|
||||
return c.pqi
|
||||
}
|
||||
|
||||
func (c *pollContainer) Add(fd int) error {
|
||||
x, err := newPoll(fd, c.shutdownFd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.pq = append(c.pq, x)
|
||||
c.pqi = append(c.pqi, x)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *pollContainer) wakeForShutdown() error {
|
||||
var buf [8]byte
|
||||
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||
_, err := unix.Write(int(c.shutdownFd), buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pollContainer) Close() error {
|
||||
errs := []error{}
|
||||
|
||||
if err := c.wakeForShutdown(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
for _, x := range c.pq {
|
||||
if err := x.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
65
overlay/tio/tio.go
Normal file
65
overlay/tio/tio.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package tio
|
||||
|
||||
import "io"
|
||||
|
||||
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
|
||||
// that don't do TSO segmentation. 65535 covers any single IP packet.
|
||||
const defaultBatchBufSize = 65535
|
||||
|
||||
// Container holds one or many Queue objects and helps close them in an orderly way
|
||||
type Container interface {
|
||||
io.Closer
|
||||
Queues() []Queue
|
||||
|
||||
// Add takes a tun fd, adds it to the container, and prepares it for use as a Queue
|
||||
Add(fd int) error
|
||||
}
|
||||
|
||||
// Queue is a readable/writable Poll queue. One Queue is driven by a single
|
||||
// read goroutine plus concurrent writers (see Write / WriteReject below).
|
||||
type Queue interface {
|
||||
io.Closer
|
||||
|
||||
// Read returns one or more packets. The returned slices are borrowed
|
||||
// from the Queue's internal buffer and are only valid until the next
|
||||
// Read or Close on this Queue - callers must encrypt or copy each
|
||||
// slice before the next call. Not safe for concurrent Reads; exactly
|
||||
// one goroutine per Queue reads.
|
||||
Read() ([][]byte, error)
|
||||
|
||||
// Write emits a single packet on the plaintext (outside→inside)
|
||||
// delivery path. May run concurrently with WriteFromSelf on the same
|
||||
// Queue, but not with itself.
|
||||
Write(p []byte) (int, error)
|
||||
|
||||
// WriteFromSelf writes a single packet that originated from the inside
|
||||
// path (reject replies or self-forward) using scratch state distinct
|
||||
// from Write, so it can run concurrently with Write on the same Queue
|
||||
// without a data race. On backends without a shared-scratch Write, a
|
||||
// trivial delegation to Write is acceptable.
|
||||
WriteFromSelf(p []byte) (int, error)
|
||||
}
|
||||
|
||||
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
|
||||
// assembled from a header prefix plus one or more borrowed payload
|
||||
// fragments, in a single vectored write (writev with a leading
|
||||
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
|
||||
// between the caller's decrypt buffer and the TUN. Backends without GSO
|
||||
// support return false from GSOSupported and coalescing is skipped.
|
||||
//
|
||||
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will
|
||||
// have filled in total length and pseudo-header partial). pays are
|
||||
// non-overlapping payload fragments whose concatenation is the full
|
||||
// superpacket payload; they are read-only from the writer's perspective
|
||||
// and must remain valid until the call returns. gsoSize is the MSS:
|
||||
// every segment except possibly the last is exactly that many bytes.
|
||||
// csumStart is the byte offset where the TCP header begins within hdr.
|
||||
//
|
||||
// # TODO fold into Queue
|
||||
//
|
||||
// hdr's TCP checksum field must already hold the pseudo-header partial
|
||||
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
|
||||
type GSOWriter interface {
|
||||
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
|
||||
GSOSupported() bool
|
||||
}
|
||||
168
overlay/tio/tio_poll_linux.go
Normal file
168
overlay/tio/tio_poll_linux.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package tio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
|
||||
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
|
||||
// prefix plus the virtio header.
|
||||
const tunReadBufSize = 65535
|
||||
|
||||
type Poll struct {
|
||||
fd int
|
||||
|
||||
readPoll [2]unix.PollFd
|
||||
writePoll [2]unix.PollFd
|
||||
closed atomic.Bool
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func newPoll(fd int, shutdownFd int) (*Poll, error) {
|
||||
if err := unix.SetNonblock(fd, true); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err)
|
||||
}
|
||||
|
||||
out := &Poll{
|
||||
fd: fd,
|
||||
readBuf: make([]byte, tunReadBufSize),
|
||||
readPoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLIN},
|
||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
writePoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// blockOnRead waits until the Poll fd is readable or shutdown has been signaled.
|
||||
// Returns os.ErrClosed if Close was called.
|
||||
func (t *Poll) blockOnRead() error {
|
||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||
var err error
|
||||
for {
|
||||
_, err = unix.Poll(t.readPoll[:], -1)
|
||||
if err != unix.EINTR {
|
||||
break
|
||||
}
|
||||
}
|
||||
tunEvents := t.readPoll[0].Revents
|
||||
shutdownEvents := t.readPoll[1].Revents
|
||||
t.readPoll[0].Revents = 0
|
||||
t.readPoll[1].Revents = 0
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||
return os.ErrClosed
|
||||
}
|
||||
if tunEvents&problemFlags != 0 {
|
||||
return os.ErrClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Poll) blockOnWrite() error {
|
||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||
var err error
|
||||
for {
|
||||
_, err = unix.Poll(t.writePoll[:], -1)
|
||||
if err != unix.EINTR {
|
||||
break
|
||||
}
|
||||
}
|
||||
tunEvents := t.writePoll[0].Revents
|
||||
shutdownEvents := t.writePoll[1].Revents
|
||||
t.writePoll[0].Revents = 0
|
||||
t.writePoll[1].Revents = 0
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||
return os.ErrClosed
|
||||
}
|
||||
if tunEvents&problemFlags != 0 {
|
||||
return os.ErrClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Poll) Read() ([][]byte, error) {
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *Poll) readOne(to []byte) (int, error) {
|
||||
for {
|
||||
n, errno := unix.Read(t.fd, to)
|
||||
if errno == nil {
|
||||
return n, nil
|
||||
}
|
||||
switch errno {
|
||||
case unix.EAGAIN:
|
||||
if err := t.blockOnRead(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case unix.EINTR:
|
||||
// retry
|
||||
case unix.EBADF:
|
||||
return 0, os.ErrClosed
|
||||
default:
|
||||
return 0, errno
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *Poll) Write(from []byte) (int, error) {
|
||||
for {
|
||||
n, errno := unix.Write(t.fd, from)
|
||||
if errno == nil {
|
||||
return n, nil
|
||||
}
|
||||
switch errno {
|
||||
case unix.EAGAIN:
|
||||
if err := t.blockOnWrite(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case unix.EINTR:
|
||||
// retry
|
||||
case unix.EBADF:
|
||||
return 0, os.ErrClosed
|
||||
default:
|
||||
return 0, errno
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Poll) Close() error {
|
||||
if t.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
//shutdownFd is owned by the container, so we should not close it
|
||||
var err error
|
||||
if t.fd >= 0 {
|
||||
err = unix.Close(t.fd)
|
||||
t.fd = -1
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Poll) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
82
overlay/tio/tun_file_linux_test.go
Normal file
82
overlay/tio/tun_file_linux_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
package tio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
||||
// The caller takes ownership of the read fd (pass it to newOffload / newFriend).
|
||||
func newReadPipe(t *testing.T) int {
|
||||
t.Helper()
|
||||
var fds [2]int
|
||||
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
|
||||
t.Fatalf("pipe2: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = unix.Close(fds[1]) })
|
||||
return fds[0]
|
||||
}
|
||||
|
||||
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||
pipe1 := newReadPipe(t)
|
||||
pipe2 := newReadPipe(t)
|
||||
parent, err := NewPollContainer()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, parent.Add(pipe1))
|
||||
require.NoError(t, parent.Add(pipe2))
|
||||
t.Cleanup(func() {
|
||||
_ = unix.Close(pipe1)
|
||||
_ = unix.Close(pipe2)
|
||||
})
|
||||
|
||||
readers := parent.Queues()
|
||||
errs := make([]error, len(readers))
|
||||
var wg sync.WaitGroup
|
||||
for i, r := range readers {
|
||||
wg.Add(1)
|
||||
go func(i int, r Queue) {
|
||||
defer wg.Done()
|
||||
_, errs[i] = r.Read()
|
||||
}(i, r)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if err := parent.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { wg.Wait(); close(done) }()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("readers did not wake")
|
||||
}
|
||||
|
||||
for i, err := range errs {
|
||||
if !errors.Is(err, os.ErrClosed) {
|
||||
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoll_Close_Idempotent(t *testing.T) {
|
||||
tf, err := newPoll(newReadPipe(t), 1)
|
||||
require.NoError(t, err)
|
||||
if err := tf.Close(); err != nil {
|
||||
t.Fatalf("first Close: %v", err)
|
||||
}
|
||||
if err := tf.Close(); err != nil {
|
||||
t.Fatalf("second Close should be a no-op, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -13,17 +13,42 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
rwc io.ReadWriteCloser
|
||||
fd int
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *slog.Logger
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
n, err := t.rwc.Read(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) Write(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
return t.rwc.Close()
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
@@ -32,10 +57,11 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
rwc: file,
|
||||
fd: deviceFd,
|
||||
vpnNetworks: vpnNetworks,
|
||||
l: l,
|
||||
readBuf: make([]byte, defaultBatchBufSize),
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
@@ -62,7 +88,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
return r
|
||||
}
|
||||
|
||||
func (t tun) Activate() error {
|
||||
func (t *tun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -99,6 +125,10 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||
func (t *tun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||
}
|
||||
|
||||
func (t *tun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
@@ -23,7 +24,7 @@ import (
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
rwc io.ReadWriteCloser
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
DefaultMTU int
|
||||
@@ -34,6 +35,9 @@ type tun struct {
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
@@ -124,11 +128,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
||||
rwc: os.NewFile(uintptr(fd), ""),
|
||||
Device: name,
|
||||
vpnNetworks: vpnNetworks,
|
||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
readBuf: make([]byte, defaultBatchBufSize),
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -158,8 +163,8 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, e
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
if t.ReadWriteCloser != nil {
|
||||
return t.ReadWriteCloser.Close()
|
||||
if t.rwc != nil {
|
||||
return t.rwc.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -502,15 +507,28 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.ReadWriteCloser.Read(buf)
|
||||
n, err := t.rwc.Read(buf)
|
||||
|
||||
copy(to, buf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
// Write is only valid for single threaded use
|
||||
func (t *tun) Write(from []byte) (int, error) {
|
||||
buf := t.out
|
||||
@@ -536,7 +554,7 @@ func (t *tun) Write(from []byte) (int, error) {
|
||||
|
||||
copy(buf[4:], from)
|
||||
|
||||
n, err := t.ReadWriteCloser.Write(buf)
|
||||
n, err := t.rwc.Write(buf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
@@ -552,6 +570,10 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
func (t *tun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
func (t *tun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
@@ -21,6 +22,24 @@ type disabledTun struct {
|
||||
tx metrics.Counter
|
||||
rx metrics.Counter
|
||||
l *slog.Logger
|
||||
numReaders int
|
||||
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *disabledTun) Read() ([][]byte, error) {
|
||||
r, ok := <-t.read
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
t.tx.Inc(1)
|
||||
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
t.l.Debug("Write payload", "raw", prettyPacket(r))
|
||||
}
|
||||
|
||||
t.batchRet[0] = r
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
|
||||
@@ -28,6 +47,7 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
|
||||
vpnNetworks: vpnNetworks,
|
||||
read: make(chan []byte, queueLen),
|
||||
l: l,
|
||||
numReaders: 1,
|
||||
}
|
||||
|
||||
if metricsEnabled {
|
||||
@@ -57,24 +77,6 @@ func (*disabledTun) Name() string {
|
||||
return "disabled"
|
||||
}
|
||||
|
||||
func (t *disabledTun) Read(b []byte) (int, error) {
|
||||
r, ok := <-t.read
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if len(r) > len(b) {
|
||||
return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
|
||||
}
|
||||
|
||||
t.tx.Inc(1)
|
||||
if t.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
t.l.Debug("Write payload", "raw", prettyPacket(r))
|
||||
}
|
||||
|
||||
return copy(b, r), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
||||
out := make([]byte, len(b))
|
||||
out = iputil.CreateICMPEchoResponse(b, out)
|
||||
@@ -106,12 +108,25 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) WriteFromSelf(b []byte) (int, error) {
|
||||
return t.Write(b)
|
||||
}
|
||||
|
||||
func (t *disabledTun) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return t, nil
|
||||
func (t *disabledTun) NewMultiQueueReader() error {
|
||||
t.numReaders++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) Readers() []tio.Queue {
|
||||
out := make([]tio.Queue, t.numReaders)
|
||||
for i := range t.numReaders {
|
||||
out[i] = t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *disabledTun) Close() error {
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
// +build linux,!android,!e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
||||
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
|
||||
func newReadPipe(t *testing.T) int {
|
||||
t.Helper()
|
||||
var fds [2]int
|
||||
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
|
||||
t.Fatalf("pipe2: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = unix.Close(fds[1]) })
|
||||
return fds[0]
|
||||
}
|
||||
|
||||
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
|
||||
tf, err := newTunFd(newReadPipe(t))
|
||||
if err != nil {
|
||||
t.Fatalf("newTunFd: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = tf.Close() })
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := tf.Read(make([]byte, 64))
|
||||
done <- err
|
||||
}()
|
||||
|
||||
// Verify Read is actually blocked in poll.
|
||||
select {
|
||||
case err := <-done:
|
||||
t.Fatalf("Read returned before shutdown signal: %v", err)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
if err := tf.wakeForShutdown(); err != nil {
|
||||
t.Fatalf("wakeForShutdown: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if !errors.Is(err, os.ErrClosed) {
|
||||
t.Fatalf("expected os.ErrClosed, got %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Read did not wake on shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||
parent, err := newTunFd(newReadPipe(t))
|
||||
if err != nil {
|
||||
t.Fatalf("newTunFd: %v", err)
|
||||
}
|
||||
friend, err := parent.newFriend(newReadPipe(t))
|
||||
if err != nil {
|
||||
_ = parent.Close()
|
||||
t.Fatalf("newFriend: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = friend.Close()
|
||||
_ = parent.Close()
|
||||
})
|
||||
|
||||
readers := []*tunFile{parent, friend}
|
||||
errs := make([]error, len(readers))
|
||||
var wg sync.WaitGroup
|
||||
for i, r := range readers {
|
||||
wg.Add(1)
|
||||
go func(i int, r *tunFile) {
|
||||
defer wg.Done()
|
||||
_, errs[i] = r.Read(make([]byte, 64))
|
||||
}(i, r)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if err := parent.wakeForShutdown(); err != nil {
|
||||
t.Fatalf("wakeForShutdown: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { wg.Wait(); close(done) }()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("readers did not wake")
|
||||
}
|
||||
|
||||
for i, err := range errs {
|
||||
if !errors.Is(err, os.ErrClosed) {
|
||||
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunFile_Close_Idempotent(t *testing.T) {
|
||||
tf, err := newTunFd(newReadPipe(t))
|
||||
if err != nil {
|
||||
t.Fatalf("newTunFd: %v", err)
|
||||
}
|
||||
if err := tf.Close(); err != nil {
|
||||
t.Fatalf("first Close: %v", err)
|
||||
}
|
||||
if err := tf.Close(); err != nil {
|
||||
t.Fatalf("second Close should be a no-op, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
@@ -20,7 +19,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
@@ -103,6 +102,9 @@ type tun struct {
|
||||
readPoll [2]unix.PollFd
|
||||
writePoll [2]unix.PollFd
|
||||
closed atomic.Bool
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
// blockOnRead waits until the tun fd is readable or shutdown has been signaled.
|
||||
@@ -157,7 +159,20 @@ func (t *tun) blockOnWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
// first 4 bytes is protocol family, in network byte order
|
||||
var head [4]byte
|
||||
iovecs := [2]syscall.Iovec{
|
||||
@@ -375,6 +390,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
fd: fd,
|
||||
readBuf: make([]byte, defaultBatchBufSize),
|
||||
shutdownR: shutdownR,
|
||||
shutdownW: shutdownW,
|
||||
readPoll: [2]unix.PollFd{
|
||||
@@ -565,8 +581,8 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
func (t *tun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
@@ -593,6 +609,10 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
|
||||
@@ -16,16 +16,41 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
rwc io.ReadWriteCloser
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *slog.Logger
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
n, err := t.rwc.Read(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) Write(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
return t.rwc.Close()
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
@@ -36,8 +61,9 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
rwc: &tunReadCloser{f: file},
|
||||
l: l,
|
||||
readBuf: make([]byte, defaultBatchBufSize),
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
@@ -155,6 +181,10 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||
func (t *tun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||
}
|
||||
|
||||
func (t *tun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,7 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -19,180 +17,15 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
|
||||
// A shared eventfd allows Close to wake all readers blocked in poll.
|
||||
type tunFile struct {
|
||||
fd int
|
||||
shutdownFd int
|
||||
lastOne bool
|
||||
readPoll [2]unix.PollFd
|
||||
writePoll [2]unix.PollFd
|
||||
closed bool
|
||||
}
|
||||
|
||||
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
|
||||
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
||||
if err := unix.SetNonblock(fd, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||
}
|
||||
return &tunFile{
|
||||
fd: fd,
|
||||
shutdownFd: r.shutdownFd,
|
||||
readPoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLIN},
|
||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
writePoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTunFd(fd int) (*tunFile, error) {
|
||||
if err := unix.SetNonblock(fd, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||
}
|
||||
|
||||
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||
}
|
||||
|
||||
out := &tunFile{
|
||||
fd: fd,
|
||||
shutdownFd: shutdownFd,
|
||||
lastOne: true,
|
||||
readPoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLIN},
|
||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
writePoll: [2]unix.PollFd{
|
||||
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||
},
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *tunFile) blockOnRead() error {
|
||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||
var err error
|
||||
for {
|
||||
_, err = unix.Poll(r.readPoll[:], -1)
|
||||
if err != unix.EINTR {
|
||||
break
|
||||
}
|
||||
}
|
||||
//always reset these!
|
||||
tunEvents := r.readPoll[0].Revents
|
||||
shutdownEvents := r.readPoll[1].Revents
|
||||
r.readPoll[0].Revents = 0
|
||||
r.readPoll[1].Revents = 0
|
||||
//do the err check before trusting the potentially bogus bits we just got
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||
return os.ErrClosed
|
||||
} else if tunEvents&problemFlags != 0 {
|
||||
return os.ErrClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tunFile) blockOnWrite() error {
|
||||
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||
var err error
|
||||
for {
|
||||
_, err = unix.Poll(r.writePoll[:], -1)
|
||||
if err != unix.EINTR {
|
||||
break
|
||||
}
|
||||
}
|
||||
//always reset these!
|
||||
tunEvents := r.writePoll[0].Revents
|
||||
shutdownEvents := r.writePoll[1].Revents
|
||||
r.writePoll[0].Revents = 0
|
||||
r.writePoll[1].Revents = 0
|
||||
//do the err check before trusting the potentially bogus bits we just got
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||
return os.ErrClosed
|
||||
} else if tunEvents&problemFlags != 0 {
|
||||
return os.ErrClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tunFile) Read(buf []byte) (int, error) {
|
||||
for {
|
||||
if n, err := unix.Read(r.fd, buf); err == nil {
|
||||
return n, nil
|
||||
} else if err == unix.EAGAIN {
|
||||
if err = r.blockOnRead(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
} else if err == unix.EINTR {
|
||||
continue
|
||||
} else if err == unix.EBADF {
|
||||
return 0, os.ErrClosed
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tunFile) Write(buf []byte) (int, error) {
|
||||
for {
|
||||
if n, err := unix.Write(r.fd, buf); err == nil {
|
||||
return n, nil
|
||||
} else if err == unix.EAGAIN {
|
||||
if err = r.blockOnWrite(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
} else if err == unix.EINTR {
|
||||
continue
|
||||
} else if err == unix.EBADF {
|
||||
return 0, os.ErrClosed
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tunFile) wakeForShutdown() error {
|
||||
var buf [8]byte
|
||||
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *tunFile) Close() error {
|
||||
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
|
||||
return nil
|
||||
}
|
||||
r.closed = true
|
||||
if r.lastOne {
|
||||
_ = unix.Close(r.shutdownFd)
|
||||
}
|
||||
return unix.Close(r.fd)
|
||||
}
|
||||
|
||||
type tun struct {
|
||||
*tunFile
|
||||
readers []*tunFile
|
||||
readers tio.Container
|
||||
closeLock sync.Mutex
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
@@ -249,44 +82,57 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
// openTunDev opens /dev/net/tun, creating the device node first if it's
|
||||
// missing (docker containers occasionally omit it).
|
||||
func openTunDev() (int, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
if os.IsNotExist(err) {
|
||||
err = os.MkdirAll("/dev/net", 0755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||
if err == nil {
|
||||
return fd, nil
|
||||
}
|
||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||
if !os.IsNotExist(err) {
|
||||
return -1, err
|
||||
}
|
||||
if err = os.MkdirAll("/dev/net", 0755); err != nil {
|
||||
return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||
}
|
||||
if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
|
||||
return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||
}
|
||||
|
||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen
|
||||
// device name on success.
|
||||
func tunSetIff(fd int, name string, flags uint16) (string, error) {
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
req.Flags = flags
|
||||
copy(req.Name[:], name)
|
||||
if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.Trim(string(req.Name[:]), "\x00"), nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
baseFlags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
nameStr := c.GetString("tun.dev", "")
|
||||
copy(req.Name[:], nameStr)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
|
||||
fd, err := openTunDev()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name, err := tunSetIff(fd, nameStr, baseFlags)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, &NameError{
|
||||
Name: nameStr,
|
||||
Underlying: err,
|
||||
return nil, &NameError{Name: nameStr, Underlying: err}
|
||||
}
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
t, err := newTunGeneric(c, l, fd, vpnNetworks)
|
||||
if err != nil {
|
||||
@@ -300,14 +146,19 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue
|
||||
|
||||
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
tfd, err := newTunFd(fd)
|
||||
container, err := tio.NewPollContainer()
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
err = container.Add(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &tun{
|
||||
tunFile: tfd,
|
||||
readers: []*tunFile{tfd},
|
||||
readers: container,
|
||||
closeLock: sync.Mutex{},
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
@@ -410,32 +261,28 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() error {
|
||||
t.closeLock.Lock()
|
||||
defer t.closeLock.Unlock()
|
||||
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
copy(req.Name[:], t.Device)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
if _, err = tunSetIff(fd, t.Device, flags); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := t.tunFile.newFriend(fd)
|
||||
err = t.readers.Add(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
t.readers = append(t.readers, out)
|
||||
|
||||
return out, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
@@ -869,6 +716,10 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
t.routeTree.Store(newTree)
|
||||
}
|
||||
|
||||
func (t *tun) Readers() []tio.Queue {
|
||||
return t.readers.Queues()
|
||||
}
|
||||
|
||||
func (t *tun) Close() error {
|
||||
t.closeLock.Lock()
|
||||
defer t.closeLock.Unlock()
|
||||
@@ -878,32 +729,10 @@ func (t *tun) Close() error {
|
||||
t.routeChan = nil
|
||||
}
|
||||
|
||||
// Signal all readers blocked in poll to wake up and exit
|
||||
_ = t.tunFile.wakeForShutdown()
|
||||
|
||||
if t.ioctlFd > 0 {
|
||||
_ = unix.Close(int(t.ioctlFd))
|
||||
t.ioctlFd = 0
|
||||
}
|
||||
|
||||
for i := range t.readers {
|
||||
if i == 0 {
|
||||
continue //we want to close the zeroth reader last
|
||||
}
|
||||
err := t.readers[i].Close()
|
||||
if err != nil {
|
||||
t.l.Error("error closing tun reader", "reader", i, "error", err)
|
||||
} else {
|
||||
t.l.Info("closed tun reader", "reader", i)
|
||||
}
|
||||
}
|
||||
|
||||
//this is t.readers[0] too
|
||||
err := t.tunFile.Close()
|
||||
if err != nil {
|
||||
t.l.Error("error closing tun reader", "reader", 0, "error", err)
|
||||
} else {
|
||||
t.l.Info("closed tun reader", "reader", 0)
|
||||
}
|
||||
return err
|
||||
return t.readers.Close()
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
|
||||
package overlay
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var runAdvMSSTests = []struct {
|
||||
name string
|
||||
|
||||
@@ -6,7 +6,6 @@ package overlay
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
@@ -66,6 +66,26 @@ type tun struct {
|
||||
l *slog.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
func (t *tun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
@@ -102,6 +122,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
readBuf: make([]byte, defaultBatchBufSize),
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -141,7 +162,7 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
rc, err := t.f.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
||||
@@ -394,8 +415,8 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||
func (t *tun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
|
||||
@@ -6,7 +6,6 @@ package overlay
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
@@ -59,6 +59,22 @@ type tun struct {
|
||||
fd int
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *tun) Read() ([][]byte, error) {
|
||||
n, err := t.readOne(t.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
@@ -95,6 +111,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
readBuf: make([]byte, defaultBatchBufSize),
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
@@ -124,7 +141,7 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
func (t *tun) readOne(to []byte) (int, error) {
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
n, err := t.f.Read(buf)
|
||||
@@ -314,8 +331,8 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
func (t *tun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
@@ -366,6 +383,10 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *tun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
@@ -27,6 +28,17 @@ type TestTun struct {
|
||||
closed atomic.Bool
|
||||
rxPackets chan []byte // Packets to receive into nebula
|
||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *TestTun) Read() ([][]byte, error) {
|
||||
p, ok := <-t.rxPackets
|
||||
if !ok {
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
t.batchRet[0] = p
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
@@ -116,6 +128,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *TestTun) WriteFromSelf(b []byte) (int, error) {
|
||||
return t.Write(b)
|
||||
}
|
||||
|
||||
func (t *TestTun) Close() error {
|
||||
if t.closed.CompareAndSwap(false, true) {
|
||||
close(t.rxPackets)
|
||||
@@ -124,19 +140,14 @@ func (t *TestTun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TestTun) Read(b []byte) (int, error) {
|
||||
p, ok := <-t.rxPackets
|
||||
if !ok {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
copy(b, p)
|
||||
return len(p), nil
|
||||
func (t *TestTun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
func (t *TestTun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||
func (t *TestTun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ package overlay
|
||||
import (
|
||||
"crypto"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/wintun"
|
||||
@@ -36,6 +36,22 @@ type winTun struct {
|
||||
l *slog.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (t *winTun) Read() ([][]byte, error) {
|
||||
n, err := t.tun.Read(t.readBuf, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.batchRet[0] = t.readBuf[:n]
|
||||
return t.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (t *winTun) WriteFromSelf(p []byte) (int, error) {
|
||||
return t.Write(p)
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
@@ -55,6 +71,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
|
||||
}
|
||||
|
||||
t := &winTun{
|
||||
readBuf: make([]byte, defaultBatchBufSize),
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
@@ -229,10 +246,6 @@ func (t *winTun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *winTun) Read(b []byte) (int, error) {
|
||||
return t.tun.Read(b, 0)
|
||||
}
|
||||
|
||||
func (t *winTun) Write(b []byte) (int, error) {
|
||||
return t.tun.Write(b, 0)
|
||||
}
|
||||
@@ -241,8 +254,12 @@ func (t *winTun) SupportsMultiqueue() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||
func (t *winTun) NewMultiQueueReader() error {
|
||||
return fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||
}
|
||||
|
||||
func (t *winTun) Readers() []tio.Queue {
|
||||
return []tio.Queue{t}
|
||||
}
|
||||
|
||||
func (t *winTun) Close() error {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/tio"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
@@ -23,17 +24,34 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
|
||||
outboundWriter: ow,
|
||||
inboundReader: ir,
|
||||
inboundWriter: iw,
|
||||
numReaders: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type UserDevice struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
numReaders int
|
||||
|
||||
outboundReader *io.PipeReader
|
||||
outboundWriter *io.PipeWriter
|
||||
|
||||
inboundReader *io.PipeReader
|
||||
inboundWriter *io.PipeWriter
|
||||
|
||||
readBuf []byte
|
||||
batchRet [1][]byte
|
||||
}
|
||||
|
||||
func (d *UserDevice) Read() ([][]byte, error) {
|
||||
if d.readBuf == nil {
|
||||
d.readBuf = make([]byte, defaultBatchBufSize)
|
||||
}
|
||||
n, err := d.outboundReader.Read(d.readBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.batchRet[0] = d.readBuf[:n]
|
||||
return d.batchRet[:], nil
|
||||
}
|
||||
|
||||
func (d *UserDevice) Activate() error {
|
||||
@@ -50,20 +68,29 @@ func (d *UserDevice) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return d, nil
|
||||
func (d *UserDevice) NewMultiQueueReader() error {
|
||||
d.numReaders++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *UserDevice) Readers() []tio.Queue {
|
||||
out := make([]tio.Queue, d.numReaders)
|
||||
for i := range d.numReaders {
|
||||
out[i] = d
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
||||
return d.inboundReader, d.outboundWriter
|
||||
}
|
||||
|
||||
func (d *UserDevice) Read(p []byte) (n int, err error) {
|
||||
return d.outboundReader.Read(p)
|
||||
}
|
||||
func (d *UserDevice) Write(p []byte) (n int, err error) {
|
||||
return d.inboundWriter.Write(p)
|
||||
}
|
||||
func (d *UserDevice) WriteFromSelf(p []byte) (n int, err error) {
|
||||
return d.Write(p)
|
||||
}
|
||||
func (d *UserDevice) Close() error {
|
||||
d.inboundWriter.Close()
|
||||
d.outboundWriter.Close()
|
||||
|
||||
24
udp/conn.go
24
udp/conn.go
@@ -8,6 +8,12 @@ import (
|
||||
|
||||
const MTU = 9001
|
||||
|
||||
// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is
|
||||
// required to accept. Callers SHOULD NOT pass more than this per call; Linux
|
||||
// backends preallocate sendmmsg scratch sized to this value, so exceeding it
|
||||
// only costs a chunked retry.
|
||||
const MaxWriteBatch = 128
|
||||
|
||||
type EncReader func(
|
||||
addr netip.AddrPort,
|
||||
payload []byte,
|
||||
@@ -16,8 +22,19 @@ type EncReader func(
|
||||
type Conn interface {
|
||||
Rebind() error
|
||||
LocalAddr() (netip.AddrPort, error)
|
||||
ListenOut(r EncReader) error
|
||||
// ListenOut invokes r for each received packet. On batch-capable
|
||||
// backends (recvmmsg), flush is called after each batch is fully
|
||||
// delivered — callers use it to flush per-batch accumulators such as
|
||||
// TUN write coalescers. Single-packet backends call flush after each
|
||||
// packet. flush must not be nil.
|
||||
ListenOut(r EncReader, flush func()) error
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
// WriteBatch sends a contiguous batch of packets, each with its own
|
||||
// destination. bufs and addrs must have the same length. Linux uses
|
||||
// sendmmsg(2) for a single syscall; other backends fall back to a
|
||||
// WriteTo loop. Returns on the first error; callers may observe a
|
||||
// partial send if some packets went out before the error.
|
||||
WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error
|
||||
ReloadConfig(c *config.C)
|
||||
SupportsMultipleReaders() bool
|
||||
Close() error
|
||||
@@ -31,7 +48,7 @@ func (NoopConn) Rebind() error {
|
||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, nil
|
||||
}
|
||||
func (NoopConn) ListenOut(_ EncReader) error {
|
||||
func (NoopConn) ListenOut(_ EncReader, _ func()) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) SupportsMultipleReaders() bool {
|
||||
@@ -40,6 +57,9 @@ func (NoopConn) SupportsMultipleReaders() bool {
|
||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
|
||||
for i, b := range bufs {
|
||||
if err := u.WriteTo(b, addrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
@@ -165,7 +174,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader) error {
|
||||
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
@@ -180,6 +189,7 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
flush()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
|
||||
for i, b := range bufs {
|
||||
if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
@@ -73,7 +82,7 @@ type rawMessage struct {
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
||||
func (u *GenericConn) ListenOut(r EncReader, flush func()) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
var lastRecvErr time.Time
|
||||
@@ -94,6 +103,7 @@ func (u *GenericConn) ListenOut(r EncReader) error {
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
flush()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
|
||||
return int(n), true, nil
|
||||
}
|
||||
|
||||
func (u *StdConn) listenOutSingle(r EncReader) error {
|
||||
func (u *StdConn) listenOutSingle(r EncReader, flush func()) error {
|
||||
var err error
|
||||
var n int
|
||||
var from netip.AddrPort
|
||||
@@ -184,15 +184,17 @@ func (u *StdConn) listenOutSingle(r EncReader) error {
|
||||
}
|
||||
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
|
||||
r(from, buffer[:n])
|
||||
flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) listenOutBatch(r EncReader) error {
|
||||
func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
|
||||
var ip netip.Addr
|
||||
var n int
|
||||
var operr error
|
||||
|
||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||
bufSize := MTU
|
||||
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
|
||||
|
||||
//reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read
|
||||
//defining it outside the loop so it gets re-used
|
||||
@@ -217,16 +219,22 @@ func (u *StdConn) listenOutBatch(r EncReader) error {
|
||||
} else {
|
||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||
}
|
||||
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||
from := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||
payload := buffers[i][:msgs[i].Len]
|
||||
|
||||
r(from, payload)
|
||||
}
|
||||
// End-of-batch: let callers (e.g. TUN write coalescer) flush any
|
||||
// state they accumulated across this batch.
|
||||
flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader) error {
|
||||
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
|
||||
if u.batch == 1 {
|
||||
return u.listenOutSingle(r)
|
||||
return u.listenOutSingle(r, flush)
|
||||
} else {
|
||||
return u.listenOutBatch(r)
|
||||
return u.listenOutBatch(r, flush)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,6 +243,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
|
||||
if len(bufs) != len(addrs) {
|
||||
return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs))
|
||||
}
|
||||
//todo use sendmmsg
|
||||
for i := 0; i < len(bufs); i++ {
|
||||
if _, err := u.udpConn.WriteToUDPAddrPort(bufs[i], addrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
b := c.GetInt("listen.read_buffer", 0)
|
||||
if b > 0 {
|
||||
|
||||
@@ -30,13 +30,13 @@ type rawMessage struct {
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
msgs := make([]rawMessage, n)
|
||||
buffers := make([][]byte, n)
|
||||
names := make([][]byte, n)
|
||||
|
||||
for i := range msgs {
|
||||
buffers[i] = make([]byte, MTU)
|
||||
buffers[i] = make([]byte, bufSize)
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{
|
||||
@@ -52,3 +52,19 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
|
||||
return msgs, buffers, names
|
||||
}
|
||||
|
||||
func setIovLen(v *iovec, n int) {
|
||||
v.Len = uint32(n)
|
||||
}
|
||||
|
||||
func setMsgIovlen(m *msghdr, n int) {
|
||||
m.Iovlen = uint32(n)
|
||||
}
|
||||
|
||||
func setMsgControllen(m *msghdr, n int) {
|
||||
m.Controllen = uint32(n)
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||
h.Len = uint32(n)
|
||||
}
|
||||
|
||||
@@ -33,13 +33,13 @@ type rawMessage struct {
|
||||
Pad0 [4]byte
|
||||
}
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
msgs := make([]rawMessage, n)
|
||||
buffers := make([][]byte, n)
|
||||
names := make([][]byte, n)
|
||||
|
||||
for i := range msgs {
|
||||
buffers[i] = make([]byte, MTU)
|
||||
buffers[i] = make([]byte, bufSize)
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{
|
||||
@@ -55,3 +55,19 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
|
||||
return msgs, buffers, names
|
||||
}
|
||||
|
||||
func setIovLen(v *iovec, n int) {
|
||||
v.Len = uint64(n)
|
||||
}
|
||||
|
||||
func setMsgIovlen(m *msghdr, n int) {
|
||||
m.Iovlen = uint64(n)
|
||||
}
|
||||
|
||||
func setMsgControllen(m *msghdr, n int) {
|
||||
m.Controllen = uint64(n)
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||
h.Len = uint64(n)
|
||||
}
|
||||
|
||||
@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
||||
func (u *RIOConn) ListenOut(r EncReader, flush func()) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
var lastRecvErr time.Time
|
||||
@@ -162,6 +162,7 @@ func (u *RIOConn) ListenOut(r EncReader) error {
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
||||
flush()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,6 +317,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
|
||||
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||
}
|
||||
|
||||
func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
|
||||
for i, b := range bufs {
|
||||
if err := u.WriteTo(b, addrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
||||
sa, err := windows.Getsockname(u.sock)
|
||||
if err != nil {
|
||||
|
||||
@@ -122,13 +122,23 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *TesterConn) ListenOut(r EncReader) error {
|
||||
func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
|
||||
for i, b := range bufs {
|
||||
if err := u.WriteTo(b, addrs[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) ListenOut(r EncReader, flush func()) error {
|
||||
for {
|
||||
select {
|
||||
case <-u.done:
|
||||
return os.ErrClosed
|
||||
case p := <-u.RxPackets:
|
||||
r(p.From, p.Data)
|
||||
flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user