diff --git a/inside.go b/inside.go index 68cb38ec..2b0ad70c 100644 --- a/inside.go +++ b/inside.go @@ -9,10 +9,11 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/overlay/batch" "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { @@ -37,7 +38,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { - _, err := f.readers[q].Write(packet) + _, err := f.readers[q].WriteFromSelf(packet) if err != nil { f.l.Error("Failed to forward to tun", "error", err) } @@ -57,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet }) if hostinfo == nil { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", "vpnAddr", fwPacket.RemoteAddr, @@ -73,10 +74,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - + f.sendInsideMessage(hostinfo, packet, nb, sendBatch, rejectBuf, q) } else { - f.rejectInside(packet, out, q) + f.rejectInside(packet, rejectBuf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(f.l).Debug("dropping outbound packet", "fwPacket", fwPacket, @@ -86,6 +86,67 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } +// sendInsideMessage encrypts a firewall-approved inside packet into the +// caller's batch slot for later sendmmsg flush. When hostinfo.remote is not +// valid we fall through to the relay slow path via the unbatched sendNoMetrics +// so relay behavior is unchanged. +func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) { + ci := hostinfo.ConnectionState + if ci.eKey == nil { + return + } + + if !hostinfo.remote.IsValid() { + // Slow path: relay fallback. Reuse rejectBuf as the ciphertext + // scratch; sendNoMetrics arranges header space for SendVia. + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q) + return + } + + scratch := sendBatch.Next() + if scratch == nil { + // Batch full: bypass batching and send this packet directly so we + // never drop traffic on over-subscribed iterations. + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q) + return + } + + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + if hostinfo.lastRebindCount != f.rebindCount { + //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is + // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Lighthouse update triggered for punch due to rebind counter", + "vpnAddrs", hostinfo.vpnAddrs, + ) + } + } + + out, err := ci.eKey.EncryptDanger(out, out, p, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", err, + "udpAddr", hostinfo.remote, + "counter", c, + ) + return + } + + sendBatch.Commit(len(out), hostinfo.remote) +} + func (f *Interface) rejectInside(packet []byte, out []byte, q int) { if !f.firewall.InSendReject { return @@ -96,7 +157,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { return } - _, err := f.readers[q].Write(out) + _, err := f.readers[q].WriteFromSelf(out) if err != nil { f.l.Error("Failed to write to tun", "error", err) } diff --git a/interface.go b/interface.go index 5fedcdd3..bc7e24d1 100644 --- a/interface.go +++ b/interface.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "log/slog" "net/netip" "sync" @@ -18,6 +17,8 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/overlay/batch" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/udp" ) @@ -88,8 +89,12 @@ type Interface struct { ctx context.Context writers []udp.Conn - readers []io.ReadWriteCloser - wg sync.WaitGroup + readers []tio.Queue + // batchers is one per tun queue, wrapping readers[i]. + // decryptToTun sends plaintext into the batch.RxBatcher; + // listenOut calls its Flush at the end of each UDP recvmmsg batch. + batchers []batch.RxBatcher + wg sync.WaitGroup // fatalErr holds the first unexpected reader error that caused shutdown. // nil means "no fatal error" (yet) @@ -187,7 +192,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), + readers: make([]tio.Queue, c.routines), + batchers: make([]batch.RxBatcher, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, @@ -245,15 +251,16 @@ func (f *Interface) activate() error { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues - var reader io.ReadWriteCloser = f.inside for i := 0; i < f.routines; i++ { if i > 0 { - reader, err = f.inside.NewMultiQueueReader() - if err != nil { + if err = f.inside.NewMultiQueueReader(); err != nil { return err } } - f.readers[i] = reader + } + f.readers = f.inside.Readers() + for i := range f.readers { + f.batchers[i] = batch.NewPassthrough(f.readers[i]) } f.wg.Add(1) // for us to wait on Close() to return @@ -311,14 +318,24 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) h := &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + coalescer := f.batchers[i] + + listener := func(fromUdpAddr netip.AddrPort, payload []byte) { + plaintext := f.batchers[i].Reserve(len(payload)) f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) - }) + } + + flusher := func() { + if err := coalescer.Flush(); err != nil { + f.l.Error("Failed to flush tun coalescer", "error", err) + } + } + + err := li.ListenOut(listener, flusher) if err != nil && !f.closed.Load() { f.l.Error("Error while reading inbound packet, closing", "error", err) @@ -328,16 +345,16 @@ func (f *Interface) listenOut(i int) { f.l.Debug("underlay reader is done", "reader", i) } -func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - packet := make([]byte, mtu) - out := make([]byte, mtu) +func (f *Interface) listenIn(reader tio.Queue, i int) { + rejectBuf := make([]byte, mtu) + sb := batch.NewSendBatch(batch.SendBatchCap, udp.MTU+32) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) for { - n, err := reader.Read(packet) + pkts, err := reader.Read() if err != nil { if !f.closed.Load() { f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) @@ -346,12 +363,29 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) + sb.Reset() + for _, pkt := range pkts { + if sb.Len() >= sb.Cap() { + f.flushBatch(sb, i) + sb.Reset() + } + f.consumeInsidePacket(pkt, fwPacket, nb, sb, rejectBuf, i, conntrackCache.Get()) + } + if sb.Len() > 0 { + f.flushBatch(sb, i) + } } f.l.Debug("overlay reader is done", "reader", i) } +func (f *Interface) flushBatch(sb batch.TxBatcher, q int) { + bufs, dsts := sb.Get() + if err := f.writers[q].WriteBatch(bufs, dsts); err != nil { + f.l.Error("Failed to write outgoing batch", "error", err, "writer", q) + } +} + func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) diff --git a/outside.go b/outside.go index 1e00a0a9..54104416 100644 --- a/outside.go +++ b/outside.go @@ -572,7 +572,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out } f.connectionManager.In(hostinfo) - _, err = f.readers[q].Write(out) + err = f.batchers[q].Commit(out) if err != nil { f.l.Error("Failed to write to tun", "error", err) } diff --git a/overlay/batch/batch.go b/overlay/batch/batch.go new file mode 100644 index 00000000..925d6044 --- /dev/null +++ b/overlay/batch/batch.go @@ -0,0 +1,33 @@ +package batch + +import "net/netip" + +type RxBatcher interface { + // Reserve creates a pkt to borrow + Reserve(sz int) []byte + // Commit borrows pkt. The caller must keep pkt valid until the next Flush + Commit(pkt []byte) error + // Flush emits every queued packet in arrival order. Returns the + // first error observed; keeps draining so one bad packet doesn't hold up + // the rest. After Flush returns, borrowed payload slices may be recycled. + Flush() error +} + +type TxBatcher interface { + // Next returns a zero-length slice with slotCap capacity over the next unused + // slot's backing bytes. The caller writes into the returned slice and then + // calls Commit with the final length and destination. Next returns nil when + // the batch is full. + Next() []byte + // Commit records the slot just returned by Next as a packet of length n + // destined for dst. + Commit(n int, dst netip.AddrPort) + // Reset clears committed slots; backing storage is retained for reuse. + Reset() + // Len returns the number of committed packets. + Len() int + // Cap returns the maximum number of slots in the batch. + Cap() int + // Get returns the buffers needed to send the batch + Get() ([][]byte, []netip.AddrPort) +} diff --git a/overlay/batch/passthrough.go b/overlay/batch/passthrough.go new file mode 100644 index 00000000..d971c303 --- /dev/null +++ b/overlay/batch/passthrough.go @@ -0,0 +1,57 @@ +package batch + +import ( + "io" + + "github.com/slackhq/nebula/udp" +) + +// Passthrough is a RxBatcher that doesn't batch anything, it just accumulates and then sends packets. +type Passthrough struct { + out io.Writer + slots [][]byte + backing []byte + cursor int +} + +func NewPassthrough(w io.Writer) *Passthrough { + const baseNumSlots = 128 + return &Passthrough{ + out: w, + slots: make([][]byte, 0, baseNumSlots), + backing: make([]byte, 0, baseNumSlots*udp.MTU), + } +} + +func (p *Passthrough) Reserve(sz int) []byte { + if len(p.backing)+sz > cap(p.backing) { + // Grow: allocate a fresh backing. Already-committed slices still + // reference the old array and remain valid until Flush drops them. + newCap := max(cap(p.backing)*2, sz) + p.backing = make([]byte, 0, newCap) + } + start := len(p.backing) + p.backing = p.backing[:start+sz] + return p.backing[start : start+sz : start+sz] //return zero length, sz-cap slice +} + +func (p *Passthrough) Commit(pkt []byte) error { + p.slots = append(p.slots, pkt) + return nil +} + +func (p *Passthrough) Flush() error { + var firstErr error + for _, s := range p.slots { + _, err := p.out.Write(s) + if err != nil && firstErr == nil { + firstErr = err + } + } + for i := range p.slots { + p.slots[i] = nil + } + p.slots = p.slots[:0] + p.backing = p.backing[:0] + return firstErr +} diff --git a/overlay/batch/tx_batch.go b/overlay/batch/tx_batch.go new file mode 100644 index 00000000..cac441d9 --- /dev/null +++ b/overlay/batch/tx_batch.go @@ -0,0 +1,61 @@ +package batch + +import "net/netip" + +const SendBatchCap = 128 + +// SendBatch accumulates encrypted UDP packets for potential TX offloading. +// One SendBatch is owned by each listenIn goroutine; no locking is needed. +// The backing storage holds up to batchCap packets of slotCap bytes each; +// bufs and dsts are parallel slices of committed slots. +type SendBatch struct { + bufs [][]byte + dsts []netip.AddrPort + backing []byte + slotCap int + batchCap int + nextSlot int +} + +func NewSendBatch(batchCap, slotCap int) *SendBatch { + return &SendBatch{ + bufs: make([][]byte, 0, batchCap), + dsts: make([]netip.AddrPort, 0, batchCap), + backing: make([]byte, batchCap*slotCap), + slotCap: slotCap, + batchCap: batchCap, + } +} + +func (b *SendBatch) Next() []byte { + if b.nextSlot >= b.batchCap { + return nil + } + start := b.nextSlot * b.slotCap + return b.backing[start : start : start+b.slotCap] //set len to 0 but cap to slotCap +} + +func (b *SendBatch) Commit(n int, dst netip.AddrPort) { + start := b.nextSlot * b.slotCap + b.bufs = append(b.bufs, b.backing[start:start+n]) + b.dsts = append(b.dsts, dst) + b.nextSlot++ +} + +func (b *SendBatch) Reset() { + b.bufs = b.bufs[:0] + b.dsts = b.dsts[:0] + b.nextSlot = 0 +} + +func (b *SendBatch) Len() int { + return len(b.bufs) +} + +func (b *SendBatch) Cap() int { + return b.batchCap +} + +func (b *SendBatch) Get() ([][]byte, []netip.AddrPort) { + return b.bufs, b.dsts +} diff --git a/overlay/batch/tx_batch_test.go b/overlay/batch/tx_batch_test.go new file mode 100644 index 00000000..32412492 --- /dev/null +++ b/overlay/batch/tx_batch_test.go @@ -0,0 +1,69 @@ +package batch + +import ( + "net/netip" + "testing" +) + +func TestSendBatchBookkeeping(t *testing.T) { + b := NewSendBatch(4, 32) + if b.Len() != 0 || b.Cap() != 4 { + t.Fatalf("fresh batch: len=%d cap=%d", b.Len(), b.Cap()) + } + + ap := netip.MustParseAddrPort("10.0.0.1:4242") + for i := 0; i < 4; i++ { + slot := b.Next() + if slot == nil { + t.Fatalf("slot %d: Next returned nil before cap", i) + } + if cap(slot) != 32 || len(slot) != 0 { + t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot)) + } + // Write a marker byte. + slot = append(slot, byte(i), byte(i+1), byte(i+2)) + b.Commit(len(slot), ap) + } + if b.Next() != nil { + t.Fatalf("Next should return nil when full") + } + if b.Len() != 4 { + t.Fatalf("Len=%d want 4", b.Len()) + } + for i, buf := range b.bufs { + if len(buf) != 3 || buf[0] != byte(i) { + t.Errorf("buf %d: %x", i, buf) + } + if b.dsts[i] != ap { + t.Errorf("dst %d: got %v want %v", i, b.dsts[i], ap) + } + } + + // Reset returns empty and Next works again. + b.Reset() + if b.Len() != 0 { + t.Fatalf("after Reset Len=%d want 0", b.Len()) + } + slot := b.Next() + if slot == nil || cap(slot) != 32 { + t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", slot == nil, cap(slot)) + } +} + +func TestSendBatchSlotsDoNotOverlap(t *testing.T) { + b := NewSendBatch(3, 8) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + // Fill three slots, each with its own sentinel byte. + for i := 0; i < 3; i++ { + s := b.Next() + s = append(s, byte(0xA0+i), byte(0xB0+i)) + b.Commit(len(s), ap) + } + + for i, buf := range b.bufs { + if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) { + t.Errorf("slot %d corrupted: %x", i, buf) + } + } +} diff --git a/overlay/device.go b/overlay/device.go index b6077aba..f8181421 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -4,15 +4,21 @@ import ( "io" "net/netip" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) +// defaultBatchBufSize is the per-Queue scratch size for Read on backends +// that don't do TSO segmentation. 65535 covers any single IP packet. +const defaultBatchBufSize = 65535 + type Device interface { - io.ReadWriteCloser + io.Closer Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways - SupportsMultiqueue() bool - NewMultiQueueReader() (io.ReadWriteCloser, error) + SupportsMultiqueue() bool //todo remove? + NewMultiQueueReader() error + Readers() []tio.Queue } diff --git a/overlay/overlaytest/noop.go b/overlay/overlaytest/noop.go index 956da7dd..99ef4d15 100644 --- a/overlay/overlaytest/noop.go +++ b/overlay/overlaytest/noop.go @@ -4,9 +4,9 @@ package overlaytest import ( "errors" - "io" "net/netip" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -31,20 +31,28 @@ func (NoopTun) Name() string { return "noop" } -func (NoopTun) Read([]byte) (int, error) { - return 0, nil +func (NoopTun) Read() ([][]byte, error) { + return nil, nil } func (NoopTun) Write([]byte) (int, error) { return 0, nil } +func (NoopTun) WriteFromSelf(p []byte) (int, error) { + return 0, nil +} + func (NoopTun) SupportsMultiqueue() bool { return false } -func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, errors.New("unsupported") +func (NoopTun) NewMultiQueueReader() error { + return errors.New("unsupported") +} + +func (NoopTun) Readers() []tio.Queue { + return []tio.Queue{NoopTun{}} } func (NoopTun) Close() error { diff --git a/overlay/tio/container_poll_linux.go b/overlay/tio/container_poll_linux.go new file mode 100644 index 00000000..fa6367e7 --- /dev/null +++ b/overlay/tio/container_poll_linux.go @@ -0,0 +1,69 @@ +package tio + +import ( + "encoding/binary" + "errors" + "fmt" + + "golang.org/x/sys/unix" +) + +type pollContainer struct { + pq []*Poll + // pqi is exactly the same as pq, but stored as the interface type + pqi []Queue + shutdownFd int +} + +func NewPollContainer() (Container, error) { + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &pollContainer{ + pq: []*Poll{}, + pqi: []Queue{}, + shutdownFd: shutdownFd, + } + + return out, nil +} + +func (c *pollContainer) Queues() []Queue { + return c.pqi +} + +func (c *pollContainer) Add(fd int) error { + x, err := newPoll(fd, c.shutdownFd) + if err != nil { + return err + } + c.pq = append(c.pq, x) + c.pqi = append(c.pqi, x) + + return nil +} + +func (c *pollContainer) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(c.shutdownFd), buf[:]) + return err +} + +func (c *pollContainer) Close() error { + errs := []error{} + + if err := c.wakeForShutdown(); err != nil { + errs = append(errs, err) + } + + for _, x := range c.pq { + if err := x.Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} diff --git a/overlay/tio/tio.go b/overlay/tio/tio.go new file mode 100644 index 00000000..792083ab --- /dev/null +++ b/overlay/tio/tio.go @@ -0,0 +1,65 @@ +package tio + +import "io" + +// defaultBatchBufSize is the per-Queue scratch size for Read on backends +// that don't do TSO segmentation. 65535 covers any single IP packet. +const defaultBatchBufSize = 65535 + +// Container holds one or many Queue objects and helps close them in an orderly way +type Container interface { + io.Closer + Queues() []Queue + + // Add takes a tun fd, adds it to the container, and prepares it for use as a Queue + Add(fd int) error +} + +// Queue is a readable/writable Poll queue. One Queue is driven by a single +// read goroutine plus concurrent writers (see Write / WriteReject below). +type Queue interface { + io.Closer + + // Read returns one or more packets. The returned slices are borrowed + // from the Queue's internal buffer and are only valid until the next + // Read or Close on this Queue - callers must encrypt or copy each + // slice before the next call. Not safe for concurrent Reads; exactly + // one goroutine per Queue reads. + Read() ([][]byte, error) + + // Write emits a single packet on the plaintext (outside→inside) + // delivery path. May run concurrently with WriteFromSelf on the same + // Queue, but not with itself. + Write(p []byte) (int, error) + + // WriteFromSelf writes a single packet that originated from the inside + // path (reject replies or self-forward) using scratch state distinct + // from Write, so it can run concurrently with Write on the same Queue + // without a data race. On backends without a shared-scratch Write, a + // trivial delegation to Write is acceptable. + WriteFromSelf(p []byte) (int, error) +} + +// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket +// assembled from a header prefix plus one or more borrowed payload +// fragments, in a single vectored write (writev with a leading +// virtio_net_hdr). This lets the coalescer avoid copying payload bytes +// between the caller's decrypt buffer and the TUN. Backends without GSO +// support return false from GSOSupported and coalescing is skipped. +// +// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will +// have filled in total length and pseudo-header partial). pays are +// non-overlapping payload fragments whose concatenation is the full +// superpacket payload; they are read-only from the writer's perspective +// and must remain valid until the call returns. gsoSize is the MSS: +// every segment except possibly the last is exactly that many bytes. +// csumStart is the byte offset where the TCP header begins within hdr. +// +// # TODO fold into Queue +// +// hdr's TCP checksum field must already hold the pseudo-header partial +// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. +type GSOWriter interface { + WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error + GSOSupported() bool +} diff --git a/overlay/tio/tio_poll_linux.go b/overlay/tio/tio_poll_linux.go new file mode 100644 index 00000000..17fb73de --- /dev/null +++ b/overlay/tio/tio_poll_linux.go @@ -0,0 +1,168 @@ +package tio + +import ( + "fmt" + "os" + "sync/atomic" + + "golang.org/x/sys/unix" +) + +// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A +// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header +// prefix plus the virtio header. +const tunReadBufSize = 65535 + +type Poll struct { + fd int + + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed atomic.Bool + + readBuf []byte + batchRet [1][]byte +} + +func newPoll(fd int, shutdownFd int) (*Poll, error) { + if err := unix.SetNonblock(fd, true); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err) + } + + out := &Poll{ + fd: fd, + readBuf: make([]byte, tunReadBufSize), + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + } + return out, nil +} + +// blockOnRead waits until the Poll fd is readable or shutdown has been signaled. +// Returns os.ErrClosed if Close was called. +func (t *Poll) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.readPoll[0].Revents + shutdownEvents := t.readPoll[1].Revents + t.readPoll[0].Revents = 0 + t.readPoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *Poll) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.writePoll[0].Revents + shutdownEvents := t.writePoll[1].Revents + t.writePoll[0].Revents = 0 + t.writePoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *Poll) Read() ([][]byte, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *Poll) readOne(to []byte) (int, error) { + for { + n, errno := unix.Read(t.fd, to) + if errno == nil { + return n, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnRead(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } + } +} + +// Write is only valid for single threaded use +func (t *Poll) Write(from []byte) (int, error) { + for { + n, errno := unix.Write(t.fd, from) + if errno == nil { + return n, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnWrite(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } + } +} + +func (t *Poll) Close() error { + if t.closed.Swap(true) { + return nil + } + //shutdownFd is owned by the container, so we should not close it + var err error + if t.fd >= 0 { + err = unix.Close(t.fd) + t.fd = -1 + } + + return err +} + +func (t *Poll) WriteFromSelf(p []byte) (int, error) { + return t.Write(p) +} diff --git a/overlay/tio/tun_file_linux_test.go b/overlay/tio/tun_file_linux_test.go new file mode 100644 index 00000000..8d66fb05 --- /dev/null +++ b/overlay/tio/tun_file_linux_test.go @@ -0,0 +1,82 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package tio + +import ( + "errors" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +// newReadPipe returns a read fd. The matching write fd is registered for cleanup. +// The caller takes ownership of the read fd (pass it to newOffload / newFriend). +func newReadPipe(t *testing.T) int { + t.Helper() + var fds [2]int + if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { + t.Fatalf("pipe2: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fds[1]) }) + return fds[0] +} + +func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) { + pipe1 := newReadPipe(t) + pipe2 := newReadPipe(t) + parent, err := NewPollContainer() + require.NoError(t, err) + require.NoError(t, parent.Add(pipe1)) + require.NoError(t, parent.Add(pipe2)) + t.Cleanup(func() { + _ = unix.Close(pipe1) + _ = unix.Close(pipe2) + }) + + readers := parent.Queues() + errs := make([]error, len(readers)) + var wg sync.WaitGroup + for i, r := range readers { + wg.Add(1) + go func(i int, r Queue) { + defer wg.Done() + _, errs[i] = r.Read() + }(i, r) + } + + time.Sleep(50 * time.Millisecond) + + if err := parent.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("readers did not wake") + } + + for i, err := range errs { + if !errors.Is(err, os.ErrClosed) { + t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) + } + } +} + +func TestPoll_Close_Idempotent(t *testing.T) { + tf, err := newPoll(newReadPipe(t), 1) + require.NoError(t, err) + if err := tf.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tf.Close(); err != nil { + t.Fatalf("second Close should be a no-op, got %v", err) + } +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 9cbb64be..b68dfcd5 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -13,17 +13,42 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser fd int vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *slog.Logger + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + n, err := t.rwc.Read(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) Write(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) WriteFromSelf(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) Close() error { + return t.rwc.Close() } func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { @@ -32,10 +57,11 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t := &tun{ - ReadWriteCloser: file, - fd: deviceFd, - vpnNetworks: vpnNetworks, - l: l, + rwc: file, + fd: deviceFd, + vpnNetworks: vpnNetworks, + l: l, + readBuf: make([]byte, defaultBatchBufSize), } err := t.reload(c, true) @@ -62,7 +88,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t tun) Activate() error { +func (t *tun) Activate() error { return nil } @@ -99,6 +125,10 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for android") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for android") +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 524ef0cd..9a4b70e6 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -23,7 +24,7 @@ import ( ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser Device string vpnNetworks []netip.Prefix DefaultMTU int @@ -34,6 +35,9 @@ type tun struct { // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte + + readBuf []byte + batchRet [1][]byte } type ifReq struct { @@ -124,11 +128,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t } t := &tun{ - ReadWriteCloser: os.NewFile(uintptr(fd), ""), - Device: name, - vpnNetworks: vpnNetworks, - DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + rwc: os.NewFile(uintptr(fd), ""), + Device: name, + vpnNetworks: vpnNetworks, + DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + readBuf: make([]byte, defaultBatchBufSize), } err = t.reload(c, true) @@ -158,8 +163,8 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*tun, e } func (t *tun) Close() error { - if t.ReadWriteCloser != nil { - return t.ReadWriteCloser.Close() + if t.rwc != nil { + return t.rwc.Close() } return nil } @@ -502,15 +507,28 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { buf := make([]byte, len(to)+4) - n, err := t.ReadWriteCloser.Read(buf) + n, err := t.rwc.Read(buf) copy(to, buf[4:]) return n - 4, err } +func (t *tun) Read() ([][]byte, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteFromSelf(p []byte) (int, error) { + return t.Write(p) +} + // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { buf := t.out @@ -536,7 +554,7 @@ func (t *tun) Write(from []byte) (int, error) { copy(buf[4:], from) - n, err := t.ReadWriteCloser.Write(buf) + n, err := t.rwc.Write(buf) return n - 4, err } @@ -552,6 +570,10 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for darwin") +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index f47880dd..5daa4797 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -10,6 +10,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -18,9 +19,27 @@ type disabledTun struct { vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us - tx metrics.Counter - rx metrics.Counter - l *slog.Logger + tx metrics.Counter + rx metrics.Counter + l *slog.Logger + numReaders int + + batchRet [1][]byte +} + +func (t *disabledTun) Read() ([][]byte, error) { + r, ok := <-t.read + if !ok { + return nil, io.EOF + } + + t.tx.Inc(1) + if t.l.Enabled(context.Background(), slog.LevelDebug) { + t.l.Debug("Write payload", "raw", prettyPacket(r)) + } + + t.batchRet[0] = r + return t.batchRet[:], nil } func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { @@ -28,6 +47,7 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), l: l, + numReaders: 1, } if metricsEnabled { @@ -57,24 +77,6 @@ func (*disabledTun) Name() string { return "disabled" } -func (t *disabledTun) Read(b []byte) (int, error) { - r, ok := <-t.read - if !ok { - return 0, io.EOF - } - - if len(r) > len(b) { - return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b)) - } - - t.tx.Inc(1) - if t.l.Enabled(context.Background(), slog.LevelDebug) { - t.l.Debug("Write payload", "raw", prettyPacket(r)) - } - - return copy(b, r), nil -} - func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { out := make([]byte, len(b)) out = iputil.CreateICMPEchoResponse(b, out) @@ -106,12 +108,25 @@ func (t *disabledTun) Write(b []byte) (int, error) { return len(b), nil } +func (t *disabledTun) WriteFromSelf(b []byte) (int, error) { + return t.Write(b) +} + func (t *disabledTun) SupportsMultiqueue() bool { return true } -func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return t, nil +func (t *disabledTun) NewMultiQueueReader() error { + t.numReaders++ + return nil +} + +func (t *disabledTun) Readers() []tio.Queue { + out := make([]tio.Queue, t.numReaders) + for i := range t.numReaders { + out[i] = t + } + return out } func (t *disabledTun) Close() error { diff --git a/overlay/tun_file_linux_test.go b/overlay/tun_file_linux_test.go deleted file mode 100644 index 5ab87e05..00000000 --- a/overlay/tun_file_linux_test.go +++ /dev/null @@ -1,120 +0,0 @@ -//go:build linux && !android && !e2e_testing -// +build linux,!android,!e2e_testing - -package overlay - -import ( - "errors" - "os" - "sync" - "testing" - "time" - - "golang.org/x/sys/unix" -) - -// newReadPipe returns a read fd. The matching write fd is registered for cleanup. -// The caller takes ownership of the read fd (pass it to newTunFd / newFriend). -func newReadPipe(t *testing.T) int { - t.Helper() - var fds [2]int - if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil { - t.Fatalf("pipe2: %v", err) - } - t.Cleanup(func() { _ = unix.Close(fds[1]) }) - return fds[0] -} - -func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) { - tf, err := newTunFd(newReadPipe(t)) - if err != nil { - t.Fatalf("newTunFd: %v", err) - } - t.Cleanup(func() { _ = tf.Close() }) - - done := make(chan error, 1) - go func() { - _, err := tf.Read(make([]byte, 64)) - done <- err - }() - - // Verify Read is actually blocked in poll. - select { - case err := <-done: - t.Fatalf("Read returned before shutdown signal: %v", err) - case <-time.After(50 * time.Millisecond): - } - - if err := tf.wakeForShutdown(); err != nil { - t.Fatalf("wakeForShutdown: %v", err) - } - - select { - case err := <-done: - if !errors.Is(err, os.ErrClosed) { - t.Fatalf("expected os.ErrClosed, got %v", err) - } - case <-time.After(2 * time.Second): - t.Fatal("Read did not wake on shutdown") - } -} - -func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) { - parent, err := newTunFd(newReadPipe(t)) - if err != nil { - t.Fatalf("newTunFd: %v", err) - } - friend, err := parent.newFriend(newReadPipe(t)) - if err != nil { - _ = parent.Close() - t.Fatalf("newFriend: %v", err) - } - t.Cleanup(func() { - _ = friend.Close() - _ = parent.Close() - }) - - readers := []*tunFile{parent, friend} - errs := make([]error, len(readers)) - var wg sync.WaitGroup - for i, r := range readers { - wg.Add(1) - go func(i int, r *tunFile) { - defer wg.Done() - _, errs[i] = r.Read(make([]byte, 64)) - }(i, r) - } - - time.Sleep(50 * time.Millisecond) - - if err := parent.wakeForShutdown(); err != nil { - t.Fatalf("wakeForShutdown: %v", err) - } - - done := make(chan struct{}) - go func() { wg.Wait(); close(done) }() - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("readers did not wake") - } - - for i, err := range errs { - if !errors.Is(err, os.ErrClosed) { - t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err) - } - } -} - -func TestTunFile_Close_Idempotent(t *testing.T) { - tf, err := newTunFd(newReadPipe(t)) - if err != nil { - t.Fatalf("newTunFd: %v", err) - } - if err := tf.Close(); err != nil { - t.Fatalf("first Close: %v", err) - } - if err := tf.Close(); err != nil { - t.Fatalf("second Close should be a no-op, got %v", err) - } -} diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 3d995553..980c3efb 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -7,7 +7,6 @@ import ( "bytes" "errors" "fmt" - "io" "io/fs" "log/slog" "net/netip" @@ -20,7 +19,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -103,6 +102,9 @@ type tun struct { readPoll [2]unix.PollFd writePoll [2]unix.PollFd closed atomic.Bool + + readBuf []byte + batchRet [1][]byte } // blockOnRead waits until the tun fd is readable or shutdown has been signaled. @@ -157,7 +159,20 @@ func (t *tun) blockOnWrite() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) Read() ([][]byte, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteFromSelf(p []byte) (int, error) { + return t.Write(p) +} + +func (t *tun) readOne(to []byte) (int, error) { // first 4 bytes is protocol family, in network byte order var head [4]byte iovecs := [2]syscall.Iovec{ @@ -375,6 +390,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, fd: fd, + readBuf: make([]byte, defaultBatchBufSize), shutdownR: shutdownR, shutdownW: shutdownW, readPoll: [2]unix.PollFd{ @@ -565,8 +581,8 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for freebsd") } func (t *tun) addRoutes(logErrors bool) error { @@ -593,6 +609,10 @@ func (t *tun) addRoutes(logErrors bool) error { return nil } +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} +} + func (t *tun) removeRoutes(routes []Route) error { for _, r := range routes { if !r.Install { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 6bfcbdfb..65dc0edc 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -16,16 +16,41 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { - io.ReadWriteCloser + rwc io.ReadWriteCloser vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *slog.Logger + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + n, err := t.rwc.Read(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) Write(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) WriteFromSelf(p []byte) (int, error) { + return t.rwc.Write(p) +} + +func (t *tun) Close() error { + return t.rwc.Close() } func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) { @@ -35,9 +60,10 @@ func newTun(_ *config.C, _ *slog.Logger, _ []netip.Prefix, _ bool) (*tun, error) func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - vpnNetworks: vpnNetworks, - ReadWriteCloser: &tunReadCloser{f: file}, - l: l, + vpnNetworks: vpnNetworks, + rwc: &tunReadCloser{f: file}, + l: l, + readBuf: make([]byte, defaultBatchBufSize), } err := t.reload(c, true) @@ -155,6 +181,10 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for ios") +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index c6cfb686..19e3ceb0 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,9 +4,7 @@ package overlay import ( - "encoding/binary" "fmt" - "io" "log/slog" "net" "net/netip" @@ -19,180 +17,15 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) -// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. -// A shared eventfd allows Close to wake all readers blocked in poll. -type tunFile struct { - fd int - shutdownFd int - lastOne bool - readPoll [2]unix.PollFd - writePoll [2]unix.PollFd - closed bool -} - -// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun -func (r *tunFile) newFriend(fd int) (*tunFile, error) { - if err := unix.SetNonblock(fd, true); err != nil { - return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) - } - return &tunFile{ - fd: fd, - shutdownFd: r.shutdownFd, - readPoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLIN}, - {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, - }, - writePoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLOUT}, - {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, - }, - }, nil -} - -func newTunFd(fd int) (*tunFile, error) { - if err := unix.SetNonblock(fd, true); err != nil { - return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) - } - - shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) - if err != nil { - return nil, fmt.Errorf("failed to create eventfd: %w", err) - } - - out := &tunFile{ - fd: fd, - shutdownFd: shutdownFd, - lastOne: true, - readPoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLIN}, - {Fd: int32(shutdownFd), Events: unix.POLLIN}, - }, - writePoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLOUT}, - {Fd: int32(shutdownFd), Events: unix.POLLIN}, - }, - } - - return out, nil -} - -func (r *tunFile) blockOnRead() error { - const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR - var err error - for { - _, err = unix.Poll(r.readPoll[:], -1) - if err != unix.EINTR { - break - } - } - //always reset these! - tunEvents := r.readPoll[0].Revents - shutdownEvents := r.readPoll[1].Revents - r.readPoll[0].Revents = 0 - r.readPoll[1].Revents = 0 - //do the err check before trusting the potentially bogus bits we just got - if err != nil { - return err - } - if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { - return os.ErrClosed - } else if tunEvents&problemFlags != 0 { - return os.ErrClosed - } - return nil -} - -func (r *tunFile) blockOnWrite() error { - const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR - var err error - for { - _, err = unix.Poll(r.writePoll[:], -1) - if err != unix.EINTR { - break - } - } - //always reset these! - tunEvents := r.writePoll[0].Revents - shutdownEvents := r.writePoll[1].Revents - r.writePoll[0].Revents = 0 - r.writePoll[1].Revents = 0 - //do the err check before trusting the potentially bogus bits we just got - if err != nil { - return err - } - if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { - return os.ErrClosed - } else if tunEvents&problemFlags != 0 { - return os.ErrClosed - } - return nil -} - -func (r *tunFile) Read(buf []byte) (int, error) { - for { - if n, err := unix.Read(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { - if err = r.blockOnRead(); err != nil { - return 0, err - } - continue - } else if err == unix.EINTR { - continue - } else if err == unix.EBADF { - return 0, os.ErrClosed - } else { - return 0, err - } - } -} - -func (r *tunFile) Write(buf []byte) (int, error) { - for { - if n, err := unix.Write(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { - if err = r.blockOnWrite(); err != nil { - return 0, err - } - continue - } else if err == unix.EINTR { - continue - } else if err == unix.EBADF { - return 0, os.ErrClosed - } else { - return 0, err - } - } -} - -func (r *tunFile) wakeForShutdown() error { - var buf [8]byte - binary.NativeEndian.PutUint64(buf[:], 1) - _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) - return err -} - -func (r *tunFile) Close() error { - if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem - return nil - } - r.closed = true - if r.lastOne { - _ = unix.Close(r.shutdownFd) - } - return unix.Close(r.fd) -} - type tun struct { - *tunFile - readers []*tunFile + readers tio.Container closeLock sync.Mutex Device string vpnNetworks []netip.Prefix @@ -249,44 +82,57 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip return t, nil } -func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +// openTunDev opens /dev/net/tun, creating the device node first if it's +// missing (docker containers occasionally omit it). +func openTunDev() (int, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - // If /dev/net/tun doesn't exist, try to create it (will happen in docker) - if os.IsNotExist(err) { - err = os.MkdirAll("/dev/net", 0755) - if err != nil { - return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) - } - err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) - if err != nil { - return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) - } - - fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) - } - } else { - return nil, err - } + if err == nil { + return fd, nil } + if !os.IsNotExist(err) { + return -1, err + } + if err = os.MkdirAll("/dev/net", 0755); err != nil { + return -1, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + if err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil { + return -1, fmt.Errorf("failed to create /dev/net/tun: %w", err) + } + fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) + if err != nil { + return -1, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) + } + return fd, nil +} +// tunSetIff runs TUNSETIFF with the given flags and returns the kernel-chosen +// device name on success. +func tunSetIff(fd int, name string, flags uint16) (string, error) { var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + req.Flags = flags + copy(req.Name[:], name) + if err := ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + return "", err + } + return strings.Trim(string(req.Name[:]), "\x00"), nil +} + +func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { + baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI) if multiqueue { - req.Flags |= unix.IFF_MULTI_QUEUE + baseFlags |= unix.IFF_MULTI_QUEUE } nameStr := c.GetString("tun.dev", "") - copy(req.Name[:], nameStr) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { - _ = unix.Close(fd) - return nil, &NameError{ - Name: nameStr, - Underlying: err, - } + + fd, err := openTunDev() + if err != nil { + return nil, err + } + name, err := tunSetIff(fd, nameStr, baseFlags) + if err != nil { + _ = unix.Close(fd) + return nil, &NameError{Name: nameStr, Underlying: err} } - name := strings.Trim(string(req.Name[:]), "\x00") t, err := newTunGeneric(c, l, fd, vpnNetworks) if err != nil { @@ -300,14 +146,19 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { - tfd, err := newTunFd(fd) + container, err := tio.NewPollContainer() if err != nil { _ = unix.Close(fd) return nil, err } + err = container.Add(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } + t := &tun{ - tunFile: tfd, - readers: []*tunFile{tfd}, + readers: container, closeLock: sync.Mutex{}, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), @@ -410,32 +261,28 @@ func (t *tun) SupportsMultiqueue() bool { return true } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() error { t.closeLock.Lock() defer t.closeLock.Unlock() fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { - return nil, err + return err } - var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) - copy(req.Name[:], t.Device) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + if _, err = tunSetIff(fd, t.Device, flags); err != nil { _ = unix.Close(fd) - return nil, err + return err } - out, err := t.tunFile.newFriend(fd) + err = t.readers.Add(fd) if err != nil { _ = unix.Close(fd) - return nil, err + return err } - t.readers = append(t.readers, out) - - return out, nil + return nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -869,6 +716,10 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { t.routeTree.Store(newTree) } +func (t *tun) Readers() []tio.Queue { + return t.readers.Queues() +} + func (t *tun) Close() error { t.closeLock.Lock() defer t.closeLock.Unlock() @@ -878,32 +729,10 @@ func (t *tun) Close() error { t.routeChan = nil } - // Signal all readers blocked in poll to wake up and exit - _ = t.tunFile.wakeForShutdown() - if t.ioctlFd > 0 { _ = unix.Close(int(t.ioctlFd)) t.ioctlFd = 0 } - for i := range t.readers { - if i == 0 { - continue //we want to close the zeroth reader last - } - err := t.readers[i].Close() - if err != nil { - t.l.Error("error closing tun reader", "reader", i, "error", err) - } else { - t.l.Info("closed tun reader", "reader", i) - } - } - - //this is t.readers[0] too - err := t.tunFile.Close() - if err != nil { - t.l.Error("error closing tun reader", "reader", 0, "error", err) - } else { - t.l.Info("closed tun reader", "reader", 0) - } - return err + return t.readers.Close() } diff --git a/overlay/tun_linux_test.go b/overlay/tun_linux_test.go index 1c1842da..1003a165 100644 --- a/overlay/tun_linux_test.go +++ b/overlay/tun_linux_test.go @@ -3,7 +3,9 @@ package overlay -import "testing" +import ( + "testing" +) var runAdvMSSTests = []struct { name string diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index c971bb6e..8275b754 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "log/slog" "net/netip" "os" @@ -17,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -66,6 +66,26 @@ type tun struct { l *slog.Logger f *os.File fd int + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteFromSelf(p []byte) (int, error) { + return t.Write(p) +} + +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -102,6 +122,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, + readBuf: make([]byte, defaultBatchBufSize), } err = t.reload(c, true) @@ -141,7 +162,7 @@ func (t *tun) Close() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { rc, err := t.f.SyscallConn() if err != nil { return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err) @@ -394,8 +415,8 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for netbsd") } func (t *tun) addRoutes(logErrors bool) error { diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 81362184..8c16b977 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,7 +6,6 @@ package overlay import ( "errors" "fmt" - "io" "log/slog" "net/netip" "os" @@ -17,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -59,6 +59,22 @@ type tun struct { fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte + + readBuf []byte + batchRet [1][]byte +} + +func (t *tun) Read() ([][]byte, error) { + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *tun) WriteFromSelf(p []byte) (int, error) { + return t.Write(p) } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -95,6 +111,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, + readBuf: make([]byte, defaultBatchBufSize), } err = t.reload(c, true) @@ -124,7 +141,7 @@ func (t *tun) Close() error { return nil } -func (t *tun) Read(to []byte) (int, error) { +func (t *tun) readOne(to []byte) (int, error) { buf := make([]byte, len(to)+4) n, err := t.f.Read(buf) @@ -314,8 +331,8 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") +func (t *tun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for openbsd") } func (t *tun) addRoutes(logErrors bool) error { @@ -366,6 +383,10 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func (t *tun) Readers() []tio.Queue { + return []tio.Queue{t} +} + func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index b2c2a0ea..22bc9f5c 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -14,6 +14,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -27,6 +28,17 @@ type TestTun struct { closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula + + batchRet [1][]byte +} + +func (t *TestTun) Read() ([][]byte, error) { + p, ok := <-t.rxPackets + if !ok { + return nil, os.ErrClosed + } + t.batchRet[0] = p + return t.batchRet[:], nil } func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { @@ -116,6 +128,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) { return len(b), nil } +func (t *TestTun) WriteFromSelf(b []byte) (int, error) { + return t.Write(b) +} + func (t *TestTun) Close() error { if t.closed.CompareAndSwap(false, true) { close(t.rxPackets) @@ -124,19 +140,14 @@ func (t *TestTun) Close() error { return nil } -func (t *TestTun) Read(b []byte) (int, error) { - p, ok := <-t.rxPackets - if !ok { - return 0, os.ErrClosed - } - copy(b, p) - return len(p), nil +func (t *TestTun) Readers() []tio.Queue { + return []tio.Queue{t} } func (t *TestTun) SupportsMultiqueue() bool { return false } -func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented") +func (t *TestTun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 680dddb3..c99d259f 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -6,7 +6,6 @@ package overlay import ( "crypto" "fmt" - "io" "log/slog" "net/netip" "os" @@ -18,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" @@ -36,6 +36,22 @@ type winTun struct { l *slog.Logger tun *wintun.NativeTun + + readBuf []byte + batchRet [1][]byte +} + +func (t *winTun) Read() ([][]byte, error) { + n, err := t.tun.Read(t.readBuf, 0) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *winTun) WriteFromSelf(p []byte) (int, error) { + return t.Write(p) } func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (Device, error) { @@ -55,6 +71,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w } t := &winTun{ + readBuf: make([]byte, defaultBatchBufSize), Device: deviceName, vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), @@ -229,10 +246,6 @@ func (t *winTun) Name() string { return t.Device } -func (t *winTun) Read(b []byte) (int, error) { - return t.tun.Read(b, 0) -} - func (t *winTun) Write(b []byte) (int, error) { return t.tun.Write(b, 0) } @@ -241,8 +254,12 @@ func (t *winTun) SupportsMultiqueue() bool { return false } -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") +func (t *winTun) NewMultiQueueReader() error { + return fmt.Errorf("TODO: multiqueue not implemented for windows") +} + +func (t *winTun) Readers() []tio.Queue { + return []tio.Queue{t} } func (t *winTun) Close() error { diff --git a/overlay/user.go b/overlay/user.go index e5f27f37..fcb1ee51 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -6,6 +6,7 @@ import ( "net/netip" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -23,17 +24,34 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { outboundWriter: ow, inboundReader: ir, inboundWriter: iw, + numReaders: 1, }, nil } type UserDevice struct { vpnNetworks []netip.Prefix + numReaders int outboundReader *io.PipeReader outboundWriter *io.PipeWriter inboundReader *io.PipeReader inboundWriter *io.PipeWriter + + readBuf []byte + batchRet [1][]byte +} + +func (d *UserDevice) Read() ([][]byte, error) { + if d.readBuf == nil { + d.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := d.outboundReader.Read(d.readBuf) + if err != nil { + return nil, err + } + d.batchRet[0] = d.readBuf[:n] + return d.batchRet[:], nil } func (d *UserDevice) Activate() error { @@ -50,20 +68,29 @@ func (d *UserDevice) SupportsMultiqueue() bool { return true } -func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return d, nil +func (d *UserDevice) NewMultiQueueReader() error { + d.numReaders++ + return nil +} + +func (d *UserDevice) Readers() []tio.Queue { + out := make([]tio.Queue, d.numReaders) + for i := range d.numReaders { + out[i] = d + } + return out } func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { return d.inboundReader, d.outboundWriter } -func (d *UserDevice) Read(p []byte) (n int, err error) { - return d.outboundReader.Read(p) -} func (d *UserDevice) Write(p []byte) (n int, err error) { return d.inboundWriter.Write(p) } +func (d *UserDevice) WriteFromSelf(p []byte) (n int, err error) { + return d.Write(p) +} func (d *UserDevice) Close() error { d.inboundWriter.Close() d.outboundWriter.Close() diff --git a/udp/conn.go b/udp/conn.go index 30d89dec..14902a76 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -8,6 +8,12 @@ import ( const MTU = 9001 +// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is +// required to accept. Callers SHOULD NOT pass more than this per call; Linux +// backends preallocate sendmmsg scratch sized to this value, so exceeding it +// only costs a chunked retry. +const MaxWriteBatch = 128 + type EncReader func( addr netip.AddrPort, payload []byte, @@ -16,8 +22,19 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) error + // ListenOut invokes r for each received packet. On batch-capable + // backends (recvmmsg), flush is called after each batch is fully + // delivered — callers use it to flush per-batch accumulators such as + // TUN write coalescers. Single-packet backends call flush after each + // packet. flush must not be nil. + ListenOut(r EncReader, flush func()) error WriteTo(b []byte, addr netip.AddrPort) error + // WriteBatch sends a contiguous batch of packets, each with its own + // destination. bufs and addrs must have the same length. Linux uses + // sendmmsg(2) for a single syscall; other backends fall back to a + // WriteTo loop. Returns on the first error; callers may observe a + // partial send if some packets went out before the error. + WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool Close() error @@ -31,7 +48,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) error { +func (NoopConn) ListenOut(_ EncReader, _ func()) error { return nil } func (NoopConn) SupportsMultipleReaders() bool { @@ -40,6 +57,9 @@ func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error { + return nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 8a4f5b18..2468c6c4 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -165,7 +174,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) for { @@ -180,6 +189,7 @@ func (u *StdConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + flush() } } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 131eb73b..c0dacedb 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -44,6 +44,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { return err } +func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -73,7 +82,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) error { +func (u *GenericConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -94,6 +103,7 @@ func (u *GenericConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + flush() } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 3e2d726a..ec840426 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) error { +func (u *StdConn) listenOutSingle(r EncReader, flush func()) error { var err error var n int var from netip.AddrPort @@ -184,15 +184,17 @@ func (u *StdConn) listenOutSingle(r EncReader) error { } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) r(from, buffer[:n]) + flush() } } -func (u *StdConn) listenOutBatch(r EncReader) error { +func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { var ip netip.Addr var n int var operr error - msgs, buffers, names := u.PrepareRawMessages(u.batch) + bufSize := MTU + msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize) //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read //defining it outside the loop so it gets re-used @@ -217,16 +219,22 @@ func (u *StdConn) listenOutBatch(r EncReader) error { } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + from := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + payload := buffers[i][:msgs[i].Len] + + r(from, payload) } + // End-of-batch: let callers (e.g. TUN write coalescer) flush any + // state they accumulated across this batch. + flush() } } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { if u.batch == 1 { - return u.listenOutSingle(r) + return u.listenOutSingle(r, flush) } else { - return u.listenOutBatch(r) + return u.listenOutBatch(r, flush) } } @@ -235,6 +243,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return err } +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + if len(bufs) != len(addrs) { + return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs)) + } + //todo use sendmmsg + for i := 0; i < len(bufs); i++ { + if _, err := u.udpConn.WriteToUDPAddrPort(bufs[i], addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..e253784b 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -30,13 +30,13 @@ type rawMessage struct { Len uint32 } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) for i := range msgs { - buffers[i] = make([]byte, MTU) + buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ @@ -52,3 +52,19 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +func setIovLen(v *iovec, n int) { + v.Len = uint32(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint32(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint32(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint32(n) +} diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..d18ca281 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -33,13 +33,13 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) for i := range msgs { - buffers[i] = make([]byte, MTU) + buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ @@ -55,3 +55,19 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +func setIovLen(v *iovec, n int) { + v.Len = uint64(n) +} + +func setMsgIovlen(m *msghdr, n int) { + m.Iovlen = uint64(n) +} + +func setMsgControllen(m *msghdr, n int) { + m.Controllen = uint64(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint64(n) +} diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index d110af19..007384b1 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) error { +func (u *RIOConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -162,6 +162,7 @@ func (u *RIOConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + flush() } } @@ -316,6 +317,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } +func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index fcd0967c..183a39ba 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -122,13 +122,23 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { } } -func (u *TesterConn) ListenOut(r EncReader) error { +func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + +func (u *TesterConn) ListenOut(r EncReader, flush func()) error { for { select { case <-u.done: return os.ErrClosed case p := <-u.RxPackets: r(p.From, p.Data) + flush() } } }