From ad6b918e4d161339166c824f3404d7fbbbc741b8 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 21 Apr 2026 13:31:16 -0500 Subject: [PATCH] checkpt --- interface.go | 21 +- main.go | 8 + .../coalesce/tcp_coalesce.go | 36 +- .../coalesce/tcp_coalesce_test.go | 32 +- overlay/device.go | 55 +- overlay/noop.go | 3 +- overlay/tio/container_gso_linux.go | 70 +++ overlay/tio/container_poll_linux.go | 69 +++ overlay/tio/tio.go | 63 +++ overlay/tio/tio_gso_linux.go | 405 ++++++++++++++ overlay/tio/tio_poll_linux.go | 205 +++++++ overlay/{ => tio}/tun_file_linux_test.go | 2 +- overlay/{ => tio}/tun_linux_offload.go | 64 +-- overlay/{ => tio}/tun_linux_offload_test.go | 14 +- overlay/tio/tun_linux_test.go | 38 ++ overlay/tio/vnethdr_linux.go | 39 ++ overlay/tun_android.go | 3 +- overlay/tun_darwin.go | 3 +- overlay/tun_disabled.go | 22 +- overlay/tun_freebsd.go | 3 +- overlay/tun_ios.go | 3 +- overlay/tun_linux.go | 518 ++---------------- overlay/tun_linux_test.go | 34 -- overlay/tun_netbsd.go | 3 +- overlay/tun_openbsd.go | 3 +- overlay/tun_tester.go | 3 +- overlay/tun_windows.go | 3 +- overlay/user.go | 15 +- 28 files changed, 1039 insertions(+), 698 deletions(-) rename tcp_coalesce.go => overlay/coalesce/tcp_coalesce.go (93%) rename tcp_coalesce_test.go => overlay/coalesce/tcp_coalesce_test.go (97%) create mode 100644 overlay/tio/container_gso_linux.go create mode 100644 overlay/tio/container_poll_linux.go create mode 100644 overlay/tio/tio.go create mode 100644 overlay/tio/tio_gso_linux.go create mode 100644 overlay/tio/tio_poll_linux.go rename overlay/{ => tio}/tun_file_linux_test.go (99%) rename overlay/{ => tio}/tun_linux_offload.go (79%) rename overlay/{ => tio}/tun_linux_offload_test.go (97%) create mode 100644 overlay/tio/tun_linux_test.go create mode 100644 overlay/tio/vnethdr_linux.go delete mode 100644 overlay/tun_linux_test.go diff --git a/interface.go b/interface.go index 590b81ea..82356971 100644 --- a/interface.go +++ b/interface.go @@ -16,6 +16,8 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/overlay/coalesce" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/udp" ) @@ -85,11 +87,11 @@ type Interface struct { conntrackCacheTimeout time.Duration writers []udp.Conn - readers []overlay.Queue + readers []tio.Queue // tunCoalescers is one tcpCoalescer per tun queue, wrapping readers[i]. // decryptToTun sends plaintext into the coalescer; listenOut calls its // Flush at the end of each UDP recvmmsg batch. - tunCoalescers []*tcpCoalescer + tunCoalescers []*coalesce.TCPCoalescer wg sync.WaitGroup // fatalErr holds the first unexpected reader error that caused shutdown. @@ -187,8 +189,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), - readers: make([]overlay.Queue, c.routines), - tunCoalescers: make([]*tcpCoalescer, c.routines), + readers: make([]tio.Queue, c.routines), + tunCoalescers: make([]*coalesce.TCPCoalescer, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, @@ -243,16 +245,17 @@ func (f *Interface) activate() error { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues - var reader overlay.Queue = f.inside for i := 0; i < f.routines; i++ { if i > 0 { - reader, err = f.inside.NewMultiQueueReader() + err = f.inside.NewMultiQueueReader() if err != nil { return err } } - f.readers[i] = reader - f.tunCoalescers[i] = newTCPCoalescer(reader) + } + f.readers = f.inside.Readers() + for i := range f.readers { + f.tunCoalescers[i] = coalesce.NewTCPCoalescer(f.readers[i]) //todo don't always do this } f.wg.Add(1) // for us to wait on Close() to return @@ -342,7 +345,7 @@ func (f *Interface) listenOut(i int) { f.l.Debugf("underlay reader %v is done", i) } -func (f *Interface) listenIn(reader overlay.Queue, i int) { +func (f *Interface) listenIn(reader tio.Queue, i int) { rejectBuf := make([]byte, mtu) batch := newSendBatch(sendBatchCap, udp.MTU+32) fwPacket := &firewall.Packet{} diff --git a/main.go b/main.go index 8adc2921..ce769519 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,10 @@ package nebula import ( "context" "fmt" + "log" "net" + "net/http" + _ "net/http/pprof" "net/netip" "runtime/debug" "strings" @@ -49,6 +52,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg l.Println(string(b)) } + //todo!!! + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + err := configLogger(l, c) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) diff --git a/tcp_coalesce.go b/overlay/coalesce/tcp_coalesce.go similarity index 93% rename from tcp_coalesce.go rename to overlay/coalesce/tcp_coalesce.go index 26ba1181..19abd9c2 100644 --- a/tcp_coalesce.go +++ b/overlay/coalesce/tcp_coalesce.go @@ -1,11 +1,11 @@ -package nebula +package coalesce import ( "bytes" "encoding/binary" "io" - "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/overlay/tio" ) // ipProtoTCP is the IANA protocol number for TCP. Hardcoded instead of @@ -66,14 +66,14 @@ type coalesceSlot struct { payIovs [][]byte } -// tcpCoalescer accumulates adjacent in-flow TCP data segments across +// TCPCoalescer accumulates adjacent in-flow TCP data segments across // multiple concurrent flows and emits each flow's run as a single TSO -// superpacket via overlay.GSOWriter. All output — coalesced or not — is +// superpacket via tio.GSOWriter. All output — coalesced or not — is // deferred until Flush so arrival order is preserved on the wire. Owns // no locks; one coalescer per TUN write queue. -type tcpCoalescer struct { +type TCPCoalescer struct { plainW io.Writer - gsoW overlay.GSOWriter // nil when the queue doesn't support TSO + gsoW tio.GSOWriter // nil when the queue doesn't support TSO // slots is the ordered event queue. Flush walks it once and emits each // entry as either a WriteGSO (coalesced) or a plainW.Write (passthrough). @@ -86,14 +86,14 @@ type tcpCoalescer struct { pool []*coalesceSlot // free list for reuse } -func newTCPCoalescer(w io.Writer) *tcpCoalescer { - c := &tcpCoalescer{ +func NewTCPCoalescer(w io.Writer) *TCPCoalescer { + c := &TCPCoalescer{ plainW: w, slots: make([]*coalesceSlot, 0, initialSlots), openSlots: make(map[flowKey]*coalesceSlot, initialSlots), pool: make([]*coalesceSlot, 0, initialSlots), } - if gw, ok := w.(overlay.GSOWriter); ok && gw.GSOSupported() { + if gw, ok := w.(tio.GSOWriter); ok && gw.GSOSupported() { c.gsoW = gw } return c @@ -197,7 +197,7 @@ func (p parsedTCP) coalesceable() bool { // Add borrows pkt. The caller must keep pkt valid until the next Flush, // whether or not the packet was coalesced — passthrough (non-admissible) // packets are queued and written at Flush time, not synchronously. -func (c *tcpCoalescer) Add(pkt []byte) error { +func (c *TCPCoalescer) Add(pkt []byte) error { if c.gsoW == nil { c.addPassthrough(pkt) return nil @@ -237,7 +237,7 @@ func (c *tcpCoalescer) Add(pkt []byte) error { // via WriteGSO; passthrough slots go out via plainW.Write. Returns the // first error observed; keeps draining so one bad packet doesn't hold up // the rest. After Flush returns, borrowed payload slices may be recycled. -func (c *tcpCoalescer) Flush() error { +func (c *TCPCoalescer) Flush() error { var first error for _, s := range c.slots { var err error @@ -261,14 +261,14 @@ func (c *tcpCoalescer) Flush() error { return first } -func (c *tcpCoalescer) addPassthrough(pkt []byte) { +func (c *TCPCoalescer) addPassthrough(pkt []byte) { s := c.take() s.passthrough = true s.rawPkt = pkt c.slots = append(c.slots, s) } -func (c *tcpCoalescer) seed(pkt []byte, info parsedTCP) { +func (c *TCPCoalescer) seed(pkt []byte, info parsedTCP) { if info.hdrLen > tcpCoalesceHdrCap || info.hdrLen+info.payLen > tcpCoalesceBufSize { // Pathological shape — can't fit our scratch, emit as-is. c.addPassthrough(pkt) @@ -297,7 +297,7 @@ func (c *tcpCoalescer) seed(pkt []byte, info parsedTCP) { // canAppend reports whether info's packet extends the slot's seed: same // header shape and stable contents, adjacent seq, not oversized, chain not // closed. -func (c *tcpCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bool { +func (c *TCPCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bool { if s.psh { return false } @@ -322,7 +322,7 @@ func (c *tcpCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bo return true } -func (c *tcpCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP) { +func (c *TCPCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP) { s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen]) s.numSeg++ s.totalPay += info.payLen @@ -332,7 +332,7 @@ func (c *tcpCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP } } -func (c *tcpCoalescer) take() *coalesceSlot { +func (c *TCPCoalescer) take() *coalesceSlot { if n := len(c.pool); n > 0 { s := c.pool[n-1] c.pool[n-1] = nil @@ -342,7 +342,7 @@ func (c *tcpCoalescer) take() *coalesceSlot { return &coalesceSlot{} } -func (c *tcpCoalescer) release(s *coalesceSlot) { +func (c *TCPCoalescer) release(s *coalesceSlot) { s.passthrough = false s.rawPkt = nil for i := range s.payIovs { @@ -357,7 +357,7 @@ func (c *tcpCoalescer) release(s *coalesceSlot) { // flushSlot patches the header and calls WriteGSO. Does not remove the // slot from c.slots. -func (c *tcpCoalescer) flushSlot(s *coalesceSlot) error { +func (c *TCPCoalescer) flushSlot(s *coalesceSlot) error { total := s.hdrLen + s.totalPay l4Len := total - s.ipHdrLen hdr := s.hdrBuf[:s.hdrLen] diff --git a/tcp_coalesce_test.go b/overlay/coalesce/tcp_coalesce_test.go similarity index 97% rename from tcp_coalesce_test.go rename to overlay/coalesce/tcp_coalesce_test.go index 9d7713fc..943f8e66 100644 --- a/tcp_coalesce_test.go +++ b/overlay/coalesce/tcp_coalesce_test.go @@ -1,4 +1,4 @@ -package nebula +package coalesce import ( "encoding/binary" @@ -114,7 +114,7 @@ const ( func TestCoalescerPassthroughWhenGSOUnavailable(t *testing.T) { w := &fakeTunWriter{gsoEnabled: false} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pkt := buildTCPv4(1000, tcpAck, []byte("hello")) if err := c.Add(pkt); err != nil { t.Fatal(err) @@ -133,7 +133,7 @@ func TestCoalescerPassthroughWhenGSOUnavailable(t *testing.T) { func TestCoalescerNonTCPPassthrough(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pkt := make([]byte, 28) pkt[0] = 0x45 binary.BigEndian.PutUint16(pkt[2:4], 28) @@ -153,7 +153,7 @@ func TestCoalescerNonTCPPassthrough(t *testing.T) { func TestCoalescerSeedThenFlushAlone(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pkt := buildTCPv4(1000, tcpAck, make([]byte, 1000)) if err := c.Add(pkt); err != nil { t.Fatal(err) @@ -180,7 +180,7 @@ func TestCoalescerSeedThenFlushAlone(t *testing.T) { func TestCoalescerCoalescesAdjacentACKs(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 1200) if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { t.Fatal(err) @@ -220,7 +220,7 @@ func TestCoalescerCoalescesAdjacentACKs(t *testing.T) { func TestCoalescerRejectsSeqGap(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 1200) if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { t.Fatal(err) @@ -239,7 +239,7 @@ func TestCoalescerRejectsSeqGap(t *testing.T) { func TestCoalescerRejectsFlagMismatch(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 1200) if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { t.Fatal(err) @@ -260,7 +260,7 @@ func TestCoalescerRejectsFlagMismatch(t *testing.T) { func TestCoalescerRejectsFIN(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) fin := buildTCPv4(1000, tcpAck|tcpFin, []byte("x")) if err := c.Add(fin); err != nil { t.Fatal(err) @@ -276,7 +276,7 @@ func TestCoalescerRejectsFIN(t *testing.T) { func TestCoalescerShortLastSegmentClosesChain(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) full := make([]byte, 1200) half := make([]byte, 500) if err := c.Add(buildTCPv4(1000, tcpAck, full)); err != nil { @@ -311,7 +311,7 @@ func TestCoalescerShortLastSegmentClosesChain(t *testing.T) { func TestCoalescerPSHFinalizesChain(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 1200) if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { t.Fatal(err) @@ -336,7 +336,7 @@ func TestCoalescerPSHFinalizesChain(t *testing.T) { func TestCoalescerRejectsDifferentFlow(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 1200) p1 := buildTCPv4(1000, tcpAck, pay) p2 := buildTCPv4(2200, tcpAck, pay) @@ -358,7 +358,7 @@ func TestCoalescerRejectsDifferentFlow(t *testing.T) { func TestCoalescerRejectsIPOptions(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 500) pkt := buildTCPv4(1000, tcpAck, pay) // Bump IHL to 6 to simulate 4 bytes of IP options. Don't actually add @@ -378,7 +378,7 @@ func TestCoalescerRejectsIPOptions(t *testing.T) { func TestCoalescerCapBySegments(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 512) seq := uint32(1000) for i := 0; i < tcpCoalesceMaxSegs+5; i++ { @@ -402,7 +402,7 @@ func TestCoalescerCapBySegments(t *testing.T) { // flows coalesce independently in a single Flush. func TestCoalescerMultipleFlowsInSameBatch(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 1200) // Flow A: sport 1000. Flow B: sport 3000. @@ -459,7 +459,7 @@ func TestCoalescerMultipleFlowsInSameBatch(t *testing.T) { // writing passthrough packets synchronously. func TestCoalescerPreservesArrivalOrder(t *testing.T) { w := &orderedFakeWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) // Sequence: coalesceable TCP, ICMP (passthrough), coalesceable TCP on // a different flow. Expected emit order: gso(X), plain(ICMP), gso(Y). pay := make([]byte, 1200) @@ -525,7 +525,7 @@ func stringSliceEq(a, b []string) bool { // packet (SYN) mid-flow only flushes its own flow, not others. func TestCoalescerInterleavedFlowsPreserveOrdering(t *testing.T) { w := &fakeTunWriter{gsoEnabled: true} - c := newTCPCoalescer(w) + c := NewTCPCoalescer(w) pay := make([]byte, 1200) // Flow A two segments. diff --git a/overlay/device.go b/overlay/device.go index 70ca01a5..f8181421 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -4,6 +4,7 @@ import ( "io" "net/netip" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -11,59 +12,13 @@ import ( // that don't do TSO segmentation. 65535 covers any single IP packet. const defaultBatchBufSize = 65535 -// Queue is a readable/writable tun queue. One Queue is driven by a single -// read goroutine plus concurrent writers (see Write / WriteReject below). -type Queue interface { - io.Closer - - // Read returns one or more packets. The returned slices are borrowed - // from the Queue's internal buffer and are only valid until the next - // Read or Close on this Queue — callers must encrypt or copy each - // slice before the next call. Not safe for concurrent Reads; exactly - // one goroutine per Queue reads. - Read() ([][]byte, error) - - // Write emits a single packet on the plaintext (outside→inside) - // delivery path. May run concurrently with WriteReject on the same - // Queue, but not with itself. - Write(p []byte) (int, error) - - // WriteReject writes a single packet that originated from the inside - // path (reject replies or self-forward) using scratch state distinct - // from Write, so it can run concurrently with Write on the same Queue - // without a data race. On backends without a shared-scratch Write, a - // trivial delegation to Write is acceptable. - WriteReject(p []byte) (int, error) -} - type Device interface { - Queue + io.Closer Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways - SupportsMultiqueue() bool - NewMultiQueueReader() (Queue, error) -} - -// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket -// assembled from a header prefix plus one or more borrowed payload -// fragments, in a single vectored write (writev with a leading -// virtio_net_hdr). This lets the coalescer avoid copying payload bytes -// between the caller's decrypt buffer and the TUN. Backends without GSO -// support return false from GSOSupported and coalescing is skipped. -// -// hdr contains the IPv4/IPv6 + TCP header prefix (mutable — callers will -// have filled in total length and pseudo-header partial). pays are -// non-overlapping payload fragments whose concatenation is the full -// superpacket payload; they are read-only from the writer's perspective -// and must remain valid until the call returns. gsoSize is the MSS: -// every segment except possibly the last is exactly that many bytes. -// csumStart is the byte offset where the TCP header begins within hdr. -// -// hdr's TCP checksum field must already hold the pseudo-header partial -// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. -type GSOWriter interface { - WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error - GSOSupported() bool + SupportsMultiqueue() bool //todo remove? + NewMultiQueueReader() error + Readers() []tio.Queue } diff --git a/overlay/noop.go b/overlay/noop.go index dc2d3fb9..614c4241 100644 --- a/overlay/noop.go +++ b/overlay/noop.go @@ -4,6 +4,7 @@ import ( "errors" "net/netip" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -41,7 +42,7 @@ func (NoopTun) SupportsMultiqueue() bool { return false } -func (NoopTun) NewMultiQueueReader() (Queue, error) { +func (NoopTun) NewMultiQueueReader() (tio.Queue, error) { return nil, errors.New("unsupported") } diff --git a/overlay/tio/container_gso_linux.go b/overlay/tio/container_gso_linux.go new file mode 100644 index 00000000..f5260d68 --- /dev/null +++ b/overlay/tio/container_gso_linux.go @@ -0,0 +1,70 @@ +package tio + +import ( + "encoding/binary" + "errors" + "fmt" + + "golang.org/x/sys/unix" +) + +type gsoContainer struct { + pq []*tunFile + // pqi is exactly the same as pq, but stored as the interface type + pqi []Queue + shutdownFd int +} + +func NewGSOContainer() (Container, error) { + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &gsoContainer{ + pq: []*tunFile{}, + pqi: []Queue{}, + shutdownFd: shutdownFd, + } + + return out, nil +} + +func (c *gsoContainer) Queues() []Queue { + return c.pqi +} + +func (c *gsoContainer) Add(fd int) error { + x, err := newTunFd(fd, c.shutdownFd) + if err != nil { + return err + } + c.pq = append(c.pq, x) + c.pqi = append(c.pqi, x) + + return nil +} + +func (c *gsoContainer) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(c.shutdownFd), buf[:]) + return err +} + +func (c *gsoContainer) Close() error { + errs := []error{} + + // Signal all readers blocked in poll to wake up and exit + if err := c.wakeForShutdown(); err != nil { + errs = append(errs, err) + } + + for _, x := range c.pq { + if err := x.Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} diff --git a/overlay/tio/container_poll_linux.go b/overlay/tio/container_poll_linux.go new file mode 100644 index 00000000..fa6367e7 --- /dev/null +++ b/overlay/tio/container_poll_linux.go @@ -0,0 +1,69 @@ +package tio + +import ( + "encoding/binary" + "errors" + "fmt" + + "golang.org/x/sys/unix" +) + +type pollContainer struct { + pq []*Poll + // pqi is exactly the same as pq, but stored as the interface type + pqi []Queue + shutdownFd int +} + +func NewPollContainer() (Container, error) { + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &pollContainer{ + pq: []*Poll{}, + pqi: []Queue{}, + shutdownFd: shutdownFd, + } + + return out, nil +} + +func (c *pollContainer) Queues() []Queue { + return c.pqi +} + +func (c *pollContainer) Add(fd int) error { + x, err := newPoll(fd, c.shutdownFd) + if err != nil { + return err + } + c.pq = append(c.pq, x) + c.pqi = append(c.pqi, x) + + return nil +} + +func (c *pollContainer) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(int(c.shutdownFd), buf[:]) + return err +} + +func (c *pollContainer) Close() error { + errs := []error{} + + if err := c.wakeForShutdown(); err != nil { + errs = append(errs, err) + } + + for _, x := range c.pq { + if err := x.Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} diff --git a/overlay/tio/tio.go b/overlay/tio/tio.go new file mode 100644 index 00000000..c567efdd --- /dev/null +++ b/overlay/tio/tio.go @@ -0,0 +1,63 @@ +package tio + +import "io" + +// defaultBatchBufSize is the per-Queue scratch size for Read on backends +// that don't do TSO segmentation. 65535 covers any single IP packet. +const defaultBatchBufSize = 65535 + +type Container interface { + Queues() []Queue + Add(fd int) error + + io.Closer +} + +// Queue is a readable/writable Poll queue. One Queue is driven by a single +// read goroutine plus concurrent writers (see Write / WriteReject below). +type Queue interface { + io.Closer + + // Read returns one or more packets. The returned slices are borrowed + // from the Queue's internal buffer and are only valid until the next + // Read or Close on this Queue — callers must encrypt or copy each + // slice before the next call. Not safe for concurrent Reads; exactly + // one goroutine per Queue reads. + Read() ([][]byte, error) + + // Write emits a single packet on the plaintext (outside→inside) + // delivery path. May run concurrently with WriteReject on the same + // Queue, but not with itself. + Write(p []byte) (int, error) + + // WriteReject writes a single packet that originated from the inside + // path (reject replies or self-forward) using scratch state distinct + // from Write, so it can run concurrently with Write on the same Queue + // without a data race. On backends without a shared-scratch Write, a + // trivial delegation to Write is acceptable. + WriteReject(p []byte) (int, error) +} + +// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket +// assembled from a header prefix plus one or more borrowed payload +// fragments, in a single vectored write (writev with a leading +// virtio_net_hdr). This lets the coalescer avoid copying payload bytes +// between the caller's decrypt buffer and the TUN. Backends without GSO +// support return false from GSOSupported and coalescing is skipped. +// +// hdr contains the IPv4/IPv6 + TCP header prefix (mutable — callers will +// have filled in total length and pseudo-header partial). pays are +// non-overlapping payload fragments whose concatenation is the full +// superpacket payload; they are read-only from the writer's perspective +// and must remain valid until the call returns. gsoSize is the MSS: +// every segment except possibly the last is exactly that many bytes. +// csumStart is the byte offset where the TCP header begins within hdr. +// +// # TODO fold into Queue +// +// hdr's TCP checksum field must already hold the pseudo-header partial +// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. +type GSOWriter interface { + WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error + GSOSupported() bool +} diff --git a/overlay/tio/tio_gso_linux.go b/overlay/tio/tio_gso_linux.go new file mode 100644 index 00000000..4ab2ca06 --- /dev/null +++ b/overlay/tio/tio_gso_linux.go @@ -0,0 +1,405 @@ +package tio + +import ( + "fmt" + "io" + "os" + "runtime" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +// Space for segmented output. Worst case is many small segments, each paying +// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling. +const tunSegBufSize = 131072 + +// tunSegBufCap is the total size we allocate for the per-reader segment +// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus +// the same again as drain headroom so a Read wake can accumulate +// additional packets after an initial big read without overflowing. +const tunSegBufCap = tunSegBufSize * 2 + +// tunDrainCap caps how many packets a single Read will accumulate via +// the post-wake drain loop. Sized to soak up a burst of small ACKs while +// bounding how much work a single caller holds before handing off. +const tunDrainCap = 64 + +// gsoInitialPayIovs is the starting capacity (in payload fragments) of +// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without +// any reallocations. +const gsoInitialPayIovs = 66 + +// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN +// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb +// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum +// verification. All packets that reach the plain Write / WriteReject paths +// already carry a valid L4 checksum (either supplied by a remote peer whose +// ciphertext we AEAD-authenticated, or produced by finishChecksum during TSO +// segmentation, or built locally by CreateRejectPacket), so trusting them is +// safe. +var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID} + +// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. +// A shared eventfd allows Close to wake all readers blocked in poll. +type tunFile struct { //todo rename GSO + fd int + shutdownFd int + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed atomic.Bool + readBuf []byte // scratch for a single raw read (virtio hdr + superpacket) + segBuf []byte // backing store for segmented output + segOff int // cursor into segBuf for the current Read drain + pending [][]byte // segments returned from the most recent Read + writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr + // rejectIovs is a second preallocated iovec scratch used exclusively by + // WriteReject (reject + self-forward from the inside path). It mirrors + // writeIovs but lets listenIn goroutines emit reject packets without + // racing with the listenOut coalescer that owns writeIovs. + rejectIovs [2]unix.Iovec + + // gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted + // by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on + // another queue never observes a half-written header. + gsoHdrBuf [virtioNetHdrLen]byte + // gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the + // virtio header + IP/TCP header + up to gsoInitialPayIovs payload + // fragments; grown on demand if a coalescer pushes more. + gsoIovs []unix.Iovec +} + +func newTunFd(fd int, shutdownFd int) (*tunFile, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + + out := &tunFile{ + fd: fd, + shutdownFd: shutdownFd, + closed: atomic.Bool{}, + readBuf: make([]byte, tunReadBufSize), + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + + segBuf: make([]byte, tunSegBufSize), + gsoIovs: make([]unix.Iovec, 2, 2+gsoInitialPayIovs), + } + + out.writeIovs[0].Base = &validVnetHdr[0] + out.writeIovs[0].SetLen(virtioNetHdrLen) + out.rejectIovs[0].Base = &validVnetHdr[0] + out.rejectIovs[0].SetLen(virtioNetHdrLen) + out.gsoIovs[0].Base = &out.gsoHdrBuf[0] + out.gsoIovs[0].SetLen(virtioNetHdrLen) + + return out, nil +} + +func (r *tunFile) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.readPoll[0].Revents + shutdownEvents := r.readPoll[1].Revents + r.readPoll[0].Revents = 0 + r.readPoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.writePoll[0].Revents + shutdownEvents := r.writePoll[1].Revents + r.writePoll[0].Revents = 0 + r.writePoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *tunFile) readRaw(buf []byte) (int, error) { + for { + if n, err := unix.Read(r.fd, buf); err == nil { + return n, nil + } else if err == unix.EAGAIN { + if err = r.blockOnRead(); err != nil { + return 0, err + } + continue + } else if err == unix.EINTR { + continue + } else if err == unix.EBADF { + return 0, os.ErrClosed + } else { + return 0, err + } + } +} + +// Read reads one or more superpackets from the tun and returns the +// resulting packets. The first read blocks via poll; once the fd is known +// readable we drain additional packets non-blocking until the kernel queue +// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of +// segBuf headroom. This amortizes the poll wake over bursts of small +// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers +// and are only valid until the next Read or Close on this Queue. +func (r *tunFile) Read() ([][]byte, error) { + r.pending = r.pending[:0] + r.segOff = 0 + + // Initial (blocking) read. Retry on decode errors so a single bad + // packet does not stall the reader. + for { + n, err := r.readRaw(r.readBuf) + if err != nil { + return nil, err + } + if err := r.decodeRead(n); err != nil { + // Drop and read again — a bad packet should not kill the reader. + continue + } + break + } + + // Drain: non-blocking reads until the kernel queue is empty, the drain + // cap is reached, or segBuf no longer has room for another worst-case + // superpacket. + for len(r.pending) < tunDrainCap && tunSegBufCap-r.segOff >= tunSegBufSize { + n, err := unix.Read(r.fd, r.readBuf) + if err != nil { + // EAGAIN / EINTR / anything else: stop draining. We already + // have a valid batch from the first read. + break + } + if n <= 0 { + break + } + if err := r.decodeRead(n); err != nil { + // Drop this packet and stop the drain; we'd rather hand off + // what we have than keep spinning here. + break + } + } + + return r.pending, nil +} + +// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends +// the segments to r.pending, and advances r.segOff by the total scratch used. +// Caller must have already ensured r.vnetHdr is true. +func (r *tunFile) decodeRead(n int) error { + if n < virtioNetHdrLen { + return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen) + } + var hdr VirtioNetHdr + hdr.decode(r.readBuf[:virtioNetHdrLen]) + before := len(r.pending) + if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf[r.segOff:]); err != nil { + return err + } + for k := before; k < len(r.pending); k++ { + r.segOff += len(r.pending[k]) + } + return nil +} + +func (r *tunFile) Write(buf []byte) (int, error) { + return r.writeWithScratch(buf, &r.writeIovs) +} + +// WriteReject emits a packet using a dedicated iovec scratch (rejectIovs) +// distinct from the one used by the coalescer's Write path. This avoids a +// data race between the inside (listenIn) goroutine emitting reject or +// self-forward packets and the outside (listenOut) goroutine flushing TCP +// coalescer passthroughs on the same tunFile. +func (r *tunFile) WriteReject(buf []byte) (int, error) { + return r.writeWithScratch(buf, &r.rejectIovs) +} + +func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) { + if len(buf) == 0 { + return 0, nil + } + // Point the payload iovec at the caller's buffer. iovs[0] is pre-wired + // to validVnetHdr during tunFile construction so we don't rebuild it here. + iovs[1].Base = &buf[0] + iovs[1].SetLen(len(buf)) + iovPtr := uintptr(unsafe.Pointer(&iovs[0])) + // The TUN fd is non-blocking (set in newTunFd / newFriend), so writev + // either completes promptly or returns EAGAIN — it cannot park the + // goroutine inside the kernel. That lets us use syscall.RawSyscall and + // skip the runtime.entersyscall / exitsyscall bookkeeping on every + // packet; we only pay that cost when we fall through to blockOnWrite. + for { + n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2) + if errno == 0 { + runtime.KeepAlive(buf) + if int(n) < virtioNetHdrLen { + return 0, io.ErrShortWrite + } + return int(n) - virtioNetHdrLen, nil + } + if errno == unix.EAGAIN { + runtime.KeepAlive(buf) + if err := r.blockOnWrite(); err != nil { + return 0, err + } + continue + } + if errno == unix.EINTR { + continue + } + if errno == unix.EBADF { + return 0, os.ErrClosed + } + runtime.KeepAlive(buf) + return 0, errno + } +} + +// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and +// can accept WriteGSO. When false, callers should fall back to per-segment +// Write calls. +func (r *tunFile) GSOSupported() bool { return true } + +// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the +// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum, +// and TCP pseudo-header partial set by the caller). pays are payload +// fragments whose concatenation forms the full coalesced payload; each +// slice is read-only and must stay valid until return. gsoSize is the MSS; +// every segment except possibly the last is exactly gsoSize bytes. +// csumStart is the byte offset where the TCP header begins within hdr. +func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error { + if len(hdr) == 0 || len(pays) == 0 { + return nil + } + + // Build the virtio_net_hdr. When pays total to <= gsoSize the kernel + // would produce a single segment; keep NEEDS_CSUM semantics but skip + // the GSO type so the kernel doesn't spuriously mark this as TSO. + vhdr := VirtioNetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + HdrLen: uint16(len(hdr)), + GSOSize: gsoSize, + CsumStart: csumStart, + CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header + } + var totalPay int + for _, p := range pays { + totalPay += len(p) + } + if totalPay > int(gsoSize) { + if isV6 { + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + } else { + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + } + } else { + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE + vhdr.GSOSize = 0 + } + vhdr.encode(r.gsoHdrBuf[:]) + + // Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is + // wired to gsoHdrBuf at construction and never changes. + need := 2 + len(pays) + if cap(r.gsoIovs) < need { + grown := make([]unix.Iovec, need) + grown[0] = r.gsoIovs[0] + r.gsoIovs = grown + } else { + r.gsoIovs = r.gsoIovs[:need] + } + r.gsoIovs[1].Base = &hdr[0] + r.gsoIovs[1].SetLen(len(hdr)) + for i, p := range pays { + r.gsoIovs[2+i].Base = &p[0] + r.gsoIovs[2+i].SetLen(len(p)) + } + + iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0])) + iovCnt := uintptr(len(r.gsoIovs)) + for { + n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt) + if errno == 0 { + runtime.KeepAlive(hdr) + runtime.KeepAlive(pays) + if int(n) < virtioNetHdrLen { + return io.ErrShortWrite + } + return nil + } + if errno == unix.EAGAIN { + runtime.KeepAlive(hdr) + runtime.KeepAlive(pays) + if err := r.blockOnWrite(); err != nil { + return err + } + continue + } + if errno == unix.EINTR { + continue + } + if errno == unix.EBADF { + return os.ErrClosed + } + runtime.KeepAlive(hdr) + runtime.KeepAlive(pays) + return errno + } +} + +func (r *tunFile) Close() error { + if r.closed.Swap(true) { + return nil + } + + //shutdownFd is owned by the container, so we should not close it + + var err error + if r.fd >= 0 { + err = unix.Close(r.fd) + r.fd = -1 + } + + return err +} diff --git a/overlay/tio/tio_poll_linux.go b/overlay/tio/tio_poll_linux.go new file mode 100644 index 00000000..4575b8d3 --- /dev/null +++ b/overlay/tio/tio_poll_linux.go @@ -0,0 +1,205 @@ +package tio + +import ( + "fmt" + "os" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A +// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header +// prefix plus the virtio header. +const tunReadBufSize = 65535 + +type Poll struct { + fd int + + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + closed atomic.Bool + + readBuf []byte + batchRet [1][]byte +} + +func newPoll(fd int, shutdownFd int) (*Poll, error) { + if err := unix.SetNonblock(fd, true); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err) + } + + out := &Poll{ + fd: fd, + readBuf: make([]byte, tunReadBufSize), + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + } + return out, nil +} + +// blockOnRead waits until the Poll fd is readable or shutdown has been signaled. +// Returns os.ErrClosed if Close was called. +func (t *Poll) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.readPoll[0].Revents + shutdownEvents := t.readPoll[1].Revents + t.readPoll[0].Revents = 0 + t.readPoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *Poll) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(t.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + tunEvents := t.writePoll[0].Revents + shutdownEvents := t.writePoll[1].Revents + t.writePoll[0].Revents = 0 + t.writePoll[1].Revents = 0 + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } + if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (t *Poll) Read() ([][]byte, error) { + if t.readBuf == nil { + t.readBuf = make([]byte, defaultBatchBufSize) + } + n, err := t.readOne(t.readBuf) + if err != nil { + return nil, err + } + t.batchRet[0] = t.readBuf[:n] + return t.batchRet[:], nil +} + +func (t *Poll) readOne(to []byte) (int, error) { + // first 4 bytes is protocol family, in network byte order + var head [4]byte + iovecs := [2]syscall.Iovec{ //todo plat-specific + {&head[0], 4}, + {&to[0], uint64(len(to))}, + } + for { + n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) + if errno == 0 { + bytesRead := int(n) + if bytesRead < 4 { + return 0, nil + } + return bytesRead - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnRead(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } + } +} + +// Write is only valid for single threaded use +func (t *Poll) Write(from []byte) (int, error) { + if len(from) <= 1 { + return 0, syscall.EIO + } + + ipVer := from[0] >> 4 + var head [4]byte + // first 4 bytes is protocol family, in network byte order + switch ipVer { + case 4: + head[3] = syscall.AF_INET + case 6: + head[3] = syscall.AF_INET6 + default: + return 0, fmt.Errorf("unable to determine IP version from packet") + } + + iovecs := [2]syscall.Iovec{ //todo plat specific + {&head[0], 4}, + {&from[0], uint64(len(from))}, + } + for { + n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) + if errno == 0 { + return int(n) - 4, nil + } + switch errno { + case unix.EAGAIN: + if err := t.blockOnWrite(); err != nil { + return 0, err + } + case unix.EINTR: + // retry + case unix.EBADF: + return 0, os.ErrClosed + default: + return 0, errno + } + } +} + +func (t *Poll) Close() error { + if t.closed.Swap(true) { + return nil + } + + //shutdownFd is owned by the container, so we should not close it + + var err error + if t.fd >= 0 { + err = unix.Close(t.fd) + t.fd = -1 + } + + return err +} + +func (t *Poll) WriteReject(p []byte) (int, error) { + return t.Write(p) +} diff --git a/overlay/tun_file_linux_test.go b/overlay/tio/tun_file_linux_test.go similarity index 99% rename from overlay/tun_file_linux_test.go rename to overlay/tio/tun_file_linux_test.go index 5ab87e05..d162c7b3 100644 --- a/overlay/tun_file_linux_test.go +++ b/overlay/tio/tun_file_linux_test.go @@ -1,7 +1,7 @@ //go:build linux && !android && !e2e_testing // +build linux,!android,!e2e_testing -package overlay +package tio import ( "errors" diff --git a/overlay/tun_linux_offload.go b/overlay/tio/tun_linux_offload.go similarity index 79% rename from overlay/tun_linux_offload.go rename to overlay/tio/tun_linux_offload.go index 2d6e9a58..cc01b3e4 100644 --- a/overlay/tun_linux_offload.go +++ b/overlay/tio/tun_linux_offload.go @@ -1,7 +1,7 @@ //go:build linux && !android && !e2e_testing // +build linux,!android,!e2e_testing -package overlay +package tio import ( "encoding/binary" @@ -10,66 +10,10 @@ import ( "golang.org/x/sys/unix" ) -// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on -// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set). -const virtioNetHdrLen = 10 - -// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A -// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header -// prefix plus the virtio header. -const tunReadBufSize = 65535 - -// Space for segmented output. Worst case is many small segments, each paying -// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling. -const tunSegBufSize = 131072 - -// tunSegBufCap is the total size we allocate for the per-reader segment -// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus -// the same again as drain headroom so a Read wake can accumulate -// additional packets after an initial big read without overflowing. -const tunSegBufCap = tunSegBufSize * 2 - -// tunDrainCap caps how many packets a single Read will accumulate via -// the post-wake drain loop. Sized to soak up a burst of small ACKs while -// bounding how much work a single caller holds before handing off. -const tunDrainCap = 64 - -type virtioNetHdr struct { - Flags uint8 - GSOType uint8 - HdrLen uint16 - GSOSize uint16 - CsumStart uint16 - CsumOffset uint16 -} - -// decode reads a virtio_net_hdr in host byte order (TUN default; we never -// call TUNSETVNETLE so the kernel matches our endianness). -func (h *virtioNetHdr) decode(b []byte) { - h.Flags = b[0] - h.GSOType = b[1] - h.HdrLen = binary.NativeEndian.Uint16(b[2:4]) - h.GSOSize = binary.NativeEndian.Uint16(b[4:6]) - h.CsumStart = binary.NativeEndian.Uint16(b[6:8]) - h.CsumOffset = binary.NativeEndian.Uint16(b[8:10]) -} - -// encode is the inverse of decode: writes the virtio_net_hdr fields into b -// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket -// on egress. -func (h *virtioNetHdr) encode(b []byte) { - b[0] = h.Flags - b[1] = h.GSOType - binary.NativeEndian.PutUint16(b[2:4], h.HdrLen) - binary.NativeEndian.PutUint16(b[4:6], h.GSOSize) - binary.NativeEndian.PutUint16(b[6:8], h.CsumStart) - binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset) -} - // segmentInto splits a TUN-side packet described by hdr into one or more // IP packets, each appended to *out as a slice of scratch. scratch must be // sized to hold every segment (including replicated headers). -func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error { +func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error { // When RSC_INFO is set the csum_start/csum_offset fields are repurposed to // carry coalescing info rather than checksum offsets. A TUN writing via // IFF_VNET_HDR should never emit this, but if it did we would silently @@ -105,7 +49,7 @@ func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) er // handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit // checksum field; we zero it, fold a full sum (the field was pre-loaded with // the pseudo-header partial sum by the kernel), and store the result. -func finishChecksum(seg []byte, hdr virtioNetHdr) error { +func finishChecksum(seg []byte, hdr VirtioNetHdr) error { cs := int(hdr.CsumStart) co := int(hdr.CsumOffset) if cs+co+2 > len(seg) { @@ -129,7 +73,7 @@ func finishChecksum(seg []byte, hdr virtioNetHdr) error { // are each summed once up front — every segment reuses those three pre-folded // uint32 values and combines them with small per-segment deltas (seq, flags, // tcpLen, ip_id, total_len) that are cheap to fold in. -func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error { +func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error { if hdr.GSOSize == 0 { return fmt.Errorf("gso_size is zero") } diff --git a/overlay/tun_linux_offload_test.go b/overlay/tio/tun_linux_offload_test.go similarity index 97% rename from overlay/tun_linux_offload_test.go rename to overlay/tio/tun_linux_offload_test.go index 650165bc..252d823d 100644 --- a/overlay/tun_linux_offload_test.go +++ b/overlay/tio/tun_linux_offload_test.go @@ -1,7 +1,7 @@ //go:build linux && !android && !e2e_testing // +build linux,!android,!e2e_testing -package overlay +package tio import ( "encoding/binary" @@ -23,7 +23,7 @@ func verifyChecksum(b []byte, pseudo uint32) bool { // buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of // `payLen` bytes split at `mss`. -func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) { +func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, VirtioNetHdr) { t.Helper() const ipLen = 20 const tcpLen = 20 @@ -53,7 +53,7 @@ func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) { pkt[ipLen+tcpLen+i] = byte(i & 0xff) } - return pkt, virtioNetHdr{ + return pkt, VirtioNetHdr{ Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, HdrLen: uint16(ipLen + tcpLen), @@ -174,7 +174,7 @@ func TestSegmentTCPv6(t *testing.T) { pkt[ipLen+tcpLen+i] = byte(i) } - hdr := virtioNetHdr{ + hdr := VirtioNetHdr{ Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6, HdrLen: uint16(ipLen + tcpLen), @@ -240,7 +240,7 @@ func TestSegmentGSONonePassesThrough(t *testing.T) { } func TestSegmentRejectsUDP(t *testing.T) { - hdr := virtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP} + hdr := VirtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP} var out [][]byte if err := segmentInto(nil, hdr, &out, nil); err == nil { t.Fatalf("expected rejection for UDP GSO") @@ -279,7 +279,7 @@ func BenchmarkSegmentTCPv4(b *testing.B) { for i := 0; i < sz.payLen; i++ { pkt[ipLen+tcpLen+i] = byte(i) } - hdr := virtioNetHdr{ + hdr := VirtioNetHdr{ Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, HdrLen: uint16(ipLen + tcpLen), @@ -312,7 +312,7 @@ func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) { } t.Cleanup(func() { _ = unix.Close(fd) }) - tf := &tunFile{fd: fd, vnetHdr: true} + tf := &tunFile{fd: fd} tf.writeIovs[0].Base = &validVnetHdr[0] tf.writeIovs[0].SetLen(virtioNetHdrLen) diff --git a/overlay/tio/tun_linux_test.go b/overlay/tio/tun_linux_test.go new file mode 100644 index 00000000..607ed828 --- /dev/null +++ b/overlay/tio/tun_linux_test.go @@ -0,0 +1,38 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package tio + +import ( + "testing" + + "github.com/slackhq/nebula/overlay" +) + +var runAdvMSSTests = []struct { + name string + tun *overlay.tun + r overlay.Route + expected int +}{ + // Standard case, default MTU is the device max MTU + {"default", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{}, 0}, + {"default-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1440}, 0}, + {"default-low", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1200}, 1160}, + + // Case where we have a route MTU set higher than the default + {"route", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{}, 1400}, + {"route-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 1440}, 1400}, + {"route-high", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 8941}, 0}, +} + +func TestTunAdvMSS(t *testing.T) { + for _, tt := range runAdvMSSTests { + t.Run(tt.name, func(t *testing.T) { + o := tt.tun.advMSS(tt.r) + if o != tt.expected { + t.Errorf("got %d, want %d", o, tt.expected) + } + }) + } +} diff --git a/overlay/tio/vnethdr_linux.go b/overlay/tio/vnethdr_linux.go new file mode 100644 index 00000000..dc4ab6cb --- /dev/null +++ b/overlay/tio/vnethdr_linux.go @@ -0,0 +1,39 @@ +package tio + +import "encoding/binary" + +// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on +// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set). +const virtioNetHdrLen = 10 + +type VirtioNetHdr struct { + Flags uint8 + GSOType uint8 + HdrLen uint16 + GSOSize uint16 + CsumStart uint16 + CsumOffset uint16 +} + +// decode reads a virtio_net_hdr in host byte order (TUN default; we never +// call TUNSETVNETLE so the kernel matches our endianness). +func (h *VirtioNetHdr) decode(b []byte) { + h.Flags = b[0] + h.GSOType = b[1] + h.HdrLen = binary.NativeEndian.Uint16(b[2:4]) + h.GSOSize = binary.NativeEndian.Uint16(b[4:6]) + h.CsumStart = binary.NativeEndian.Uint16(b[6:8]) + h.CsumOffset = binary.NativeEndian.Uint16(b[8:10]) +} + +// encode is the inverse of decode: writes the virtio_net_hdr fields into b +// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket +// on egress. +func (h *VirtioNetHdr) encode(b []byte) { + b[0] = h.Flags + b[1] = h.GSOType + binary.NativeEndian.PutUint16(b[2:4], h.HdrLen) + binary.NativeEndian.PutUint16(b[4:6], h.GSOSize) + binary.NativeEndian.PutUint16(b[6:8], h.CsumStart) + binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 62de337d..8f541e18 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -13,6 +13,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -126,6 +127,6 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (Queue, error) { +func (t *tun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 7f50c705..c2843697 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -572,6 +573,6 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (Queue, error) { +func (t *tun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 8a691ae0..8ce9ef1d 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -9,6 +9,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -17,9 +18,10 @@ type disabledTun struct { vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us - tx metrics.Counter - rx metrics.Counter - l *logrus.Logger + tx metrics.Counter + rx metrics.Counter + l *logrus.Logger + numReaders int batchRet [1][]byte } @@ -44,6 +46,7 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), l: l, + numReaders: 1, } if metricsEnabled { @@ -112,8 +115,17 @@ func (t *disabledTun) SupportsMultiqueue() bool { return true } -func (t *disabledTun) NewMultiQueueReader() (Queue, error) { - return t, nil +func (t *disabledTun) NewMultiQueueReader() error { + t.numReaders++ + return nil +} + +func (t *disabledTun) Readers() []tio.Queue { + out := make([]tio.Queue, t.numReaders) + for i := range t.numReaders { + out[i] = t + } + return out } func (t *disabledTun) Close() error { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 68278932..a73ed180 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -18,6 +18,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -581,7 +582,7 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (Queue, error) { +func (t *tun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index ebf134b8..e64a5663 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -182,6 +183,6 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (Queue, error) { +func (t *tun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 3a75685b..a81a958d 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,478 +4,28 @@ package overlay import ( - "encoding/binary" "fmt" - "io" "net" "net/netip" "os" - "runtime" "strings" "sync" "sync/atomic" - "syscall" "time" "unsafe" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) -// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. -// A shared eventfd allows Close to wake all readers blocked in poll. -type tunFile struct { - fd int - shutdownFd int - lastOne bool - readPoll [2]unix.PollFd - writePoll [2]unix.PollFd - closed bool - - // vnetHdr is true when this fd was opened with IFF_VNET_HDR and the - // kernel successfully accepted TUNSETOFFLOAD. Reads include a leading - // virtio_net_hdr and may carry a TSO superpacket we must segment; - // writes must prepend a zeroed virtio_net_hdr. - vnetHdr bool - readBuf []byte // scratch for a single raw read (virtio hdr + superpacket) - segBuf []byte // backing store for segmented output - segOff int // cursor into segBuf for the current Read drain - pending [][]byte // segments returned from the most recent Read - writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr - // rejectIovs is a second preallocated iovec scratch used exclusively by - // WriteReject (reject + self-forward from the inside path). It mirrors - // writeIovs but lets listenIn goroutines emit reject packets without - // racing with the listenOut coalescer that owns writeIovs. - rejectIovs [2]unix.Iovec - - // gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted - // by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on - // another queue never observes a half-written header. - gsoHdrBuf [virtioNetHdrLen]byte - // gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the - // virtio header + IP/TCP header + up to gsoInitialPayIovs payload - // fragments; grown on demand if a coalescer pushes more. - gsoIovs []unix.Iovec -} - -// gsoInitialPayIovs is the starting capacity (in payload fragments) of -// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without -// any reallocations. -const gsoInitialPayIovs = 66 - -// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN -// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb -// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum -// verification. All packets that reach the plain Write / WriteReject paths -// already carry a valid L4 checksum (either supplied by a remote peer whose -// ciphertext we AEAD-authenticated, or produced by finishChecksum during TSO -// segmentation, or built locally by CreateRejectPacket), so trusting them is -// safe. -var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID} - -// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun -func (r *tunFile) newFriend(fd int) (*tunFile, error) { - if err := unix.SetNonblock(fd, true); err != nil { - return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) - } - out := &tunFile{ - fd: fd, - shutdownFd: r.shutdownFd, - vnetHdr: r.vnetHdr, - readBuf: make([]byte, tunReadBufSize), - readPoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLIN}, - {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, - }, - writePoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLOUT}, - {Fd: int32(r.shutdownFd), Events: unix.POLLIN}, - }, - } - if r.vnetHdr { - out.segBuf = make([]byte, tunSegBufCap) - out.writeIovs[0].Base = &validVnetHdr[0] - out.writeIovs[0].SetLen(virtioNetHdrLen) - out.rejectIovs[0].Base = &validVnetHdr[0] - out.rejectIovs[0].SetLen(virtioNetHdrLen) - out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs) - out.gsoIovs[0].Base = &out.gsoHdrBuf[0] - out.gsoIovs[0].SetLen(virtioNetHdrLen) - } - return out, nil -} - -func newTunFd(fd int, vnetHdr bool) (*tunFile, error) { - if err := unix.SetNonblock(fd, true); err != nil { - return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) - } - - shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) - if err != nil { - return nil, fmt.Errorf("failed to create eventfd: %w", err) - } - - out := &tunFile{ - fd: fd, - shutdownFd: shutdownFd, - lastOne: true, - vnetHdr: vnetHdr, - readBuf: make([]byte, tunReadBufSize), - readPoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLIN}, - {Fd: int32(shutdownFd), Events: unix.POLLIN}, - }, - writePoll: [2]unix.PollFd{ - {Fd: int32(fd), Events: unix.POLLOUT}, - {Fd: int32(shutdownFd), Events: unix.POLLIN}, - }, - } - if vnetHdr { - out.segBuf = make([]byte, tunSegBufCap) - out.writeIovs[0].Base = &validVnetHdr[0] - out.writeIovs[0].SetLen(virtioNetHdrLen) - out.rejectIovs[0].Base = &validVnetHdr[0] - out.rejectIovs[0].SetLen(virtioNetHdrLen) - out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs) - out.gsoIovs[0].Base = &out.gsoHdrBuf[0] - out.gsoIovs[0].SetLen(virtioNetHdrLen) - } - - return out, nil -} - -func (r *tunFile) blockOnRead() error { - const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR - var err error - for { - _, err = unix.Poll(r.readPoll[:], -1) - if err != unix.EINTR { - break - } - } - //always reset these! - tunEvents := r.readPoll[0].Revents - shutdownEvents := r.readPoll[1].Revents - r.readPoll[0].Revents = 0 - r.readPoll[1].Revents = 0 - //do the err check before trusting the potentially bogus bits we just got - if err != nil { - return err - } - if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { - return os.ErrClosed - } else if tunEvents&problemFlags != 0 { - return os.ErrClosed - } - return nil -} - -func (r *tunFile) blockOnWrite() error { - const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR - var err error - for { - _, err = unix.Poll(r.writePoll[:], -1) - if err != unix.EINTR { - break - } - } - //always reset these! - tunEvents := r.writePoll[0].Revents - shutdownEvents := r.writePoll[1].Revents - r.writePoll[0].Revents = 0 - r.writePoll[1].Revents = 0 - //do the err check before trusting the potentially bogus bits we just got - if err != nil { - return err - } - if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { - return os.ErrClosed - } else if tunEvents&problemFlags != 0 { - return os.ErrClosed - } - return nil -} - -func (r *tunFile) readRaw(buf []byte) (int, error) { - for { - if n, err := unix.Read(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { - if err = r.blockOnRead(); err != nil { - return 0, err - } - continue - } else if err == unix.EINTR { - continue - } else if err == unix.EBADF { - return 0, os.ErrClosed - } else { - return 0, err - } - } -} - -// Read reads one or more superpackets from the tun and returns the -// resulting packets. The first read blocks via poll; once the fd is known -// readable we drain additional packets non-blocking until the kernel queue -// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of -// segBuf headroom. This amortizes the poll wake over bursts of small -// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers -// and are only valid until the next Read or Close on this Queue. -func (r *tunFile) Read() ([][]byte, error) { - r.pending = r.pending[:0] - r.segOff = 0 - - // Initial (blocking) read. Retry on decode errors so a single bad - // packet does not stall the reader. - for { - n, err := r.readRaw(r.readBuf) - if err != nil { - return nil, err - } - if !r.vnetHdr { - r.pending = append(r.pending, r.readBuf[:n]) - // Non-vnetHdr mode shares one readBuf so we can't drain safely - // without copying; return the single packet as before. - return r.pending, nil - } - if err := r.decodeRead(n); err != nil { - // Drop and read again — a bad packet should not kill the reader. - continue - } - break - } - - // Drain: non-blocking reads until the kernel queue is empty, the drain - // cap is reached, or segBuf no longer has room for another worst-case - // superpacket. - for len(r.pending) < tunDrainCap && tunSegBufCap-r.segOff >= tunSegBufSize { - n, err := unix.Read(r.fd, r.readBuf) - if err != nil { - // EAGAIN / EINTR / anything else: stop draining. We already - // have a valid batch from the first read. - break - } - if n <= 0 { - break - } - if err := r.decodeRead(n); err != nil { - // Drop this packet and stop the drain; we'd rather hand off - // what we have than keep spinning here. - break - } - } - - return r.pending, nil -} - -// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends -// the segments to r.pending, and advances r.segOff by the total scratch used. -// Caller must have already ensured r.vnetHdr is true. -func (r *tunFile) decodeRead(n int) error { - if n < virtioNetHdrLen { - return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen) - } - var hdr virtioNetHdr - hdr.decode(r.readBuf[:virtioNetHdrLen]) - before := len(r.pending) - if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf[r.segOff:]); err != nil { - return err - } - for k := before; k < len(r.pending); k++ { - r.segOff += len(r.pending[k]) - } - return nil -} - -func (r *tunFile) Write(buf []byte) (int, error) { - return r.writeWithScratch(buf, &r.writeIovs) -} - -// WriteReject emits a packet using a dedicated iovec scratch (rejectIovs) -// distinct from the one used by the coalescer's Write path. This avoids a -// data race between the inside (listenIn) goroutine emitting reject or -// self-forward packets and the outside (listenOut) goroutine flushing TCP -// coalescer passthroughs on the same tunFile. -func (r *tunFile) WriteReject(buf []byte) (int, error) { - return r.writeWithScratch(buf, &r.rejectIovs) -} - -func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) { - if !r.vnetHdr { - for { - if n, err := unix.Write(r.fd, buf); err == nil { - return n, nil - } else if err == unix.EAGAIN { - if err = r.blockOnWrite(); err != nil { - return 0, err - } - continue - } else if err == unix.EINTR { - continue - } else if err == unix.EBADF { - return 0, os.ErrClosed - } else { - return 0, err - } - } - } - - if len(buf) == 0 { - return 0, nil - } - // Point the payload iovec at the caller's buffer. iovs[0] is pre-wired - // to validVnetHdr during tunFile construction so we don't rebuild it here. - iovs[1].Base = &buf[0] - iovs[1].SetLen(len(buf)) - iovPtr := uintptr(unsafe.Pointer(&iovs[0])) - // The TUN fd is non-blocking (set in newTunFd / newFriend), so writev - // either completes promptly or returns EAGAIN — it cannot park the - // goroutine inside the kernel. That lets us use syscall.RawSyscall and - // skip the runtime.entersyscall / exitsyscall bookkeeping on every - // packet; we only pay that cost when we fall through to blockOnWrite. - for { - n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2) - if errno == 0 { - runtime.KeepAlive(buf) - if int(n) < virtioNetHdrLen { - return 0, io.ErrShortWrite - } - return int(n) - virtioNetHdrLen, nil - } - if errno == unix.EAGAIN { - runtime.KeepAlive(buf) - if err := r.blockOnWrite(); err != nil { - return 0, err - } - continue - } - if errno == unix.EINTR { - continue - } - runtime.KeepAlive(buf) - return 0, errno - } -} - -// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and -// can accept WriteGSO. When false, callers should fall back to per-segment -// Write calls. -func (r *tunFile) GSOSupported() bool { return r.vnetHdr } - -// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the -// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum, -// and TCP pseudo-header partial set by the caller). pays are payload -// fragments whose concatenation forms the full coalesced payload; each -// slice is read-only and must stay valid until return. gsoSize is the MSS; -// every segment except possibly the last is exactly gsoSize bytes. -// csumStart is the byte offset where the TCP header begins within hdr. -func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error { - if !r.vnetHdr { - return fmt.Errorf("WriteGSO called on tun without IFF_VNET_HDR") - } - if len(hdr) == 0 || len(pays) == 0 { - return nil - } - - // Build the virtio_net_hdr. When pays total to <= gsoSize the kernel - // would produce a single segment; keep NEEDS_CSUM semantics but skip - // the GSO type so the kernel doesn't spuriously mark this as TSO. - vhdr := virtioNetHdr{ - Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - HdrLen: uint16(len(hdr)), - GSOSize: gsoSize, - CsumStart: csumStart, - CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header - } - var totalPay int - for _, p := range pays { - totalPay += len(p) - } - if totalPay > int(gsoSize) { - if isV6 { - vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - } else { - vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - } - } else { - vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE - vhdr.GSOSize = 0 - } - vhdr.encode(r.gsoHdrBuf[:]) - - // Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is - // wired to gsoHdrBuf at construction and never changes. - need := 2 + len(pays) - if cap(r.gsoIovs) < need { - grown := make([]unix.Iovec, need) - grown[0] = r.gsoIovs[0] - r.gsoIovs = grown - } else { - r.gsoIovs = r.gsoIovs[:need] - } - r.gsoIovs[1].Base = &hdr[0] - r.gsoIovs[1].SetLen(len(hdr)) - for i, p := range pays { - r.gsoIovs[2+i].Base = &p[0] - r.gsoIovs[2+i].SetLen(len(p)) - } - - iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0])) - iovCnt := uintptr(len(r.gsoIovs)) - for { - n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt) - if errno == 0 { - runtime.KeepAlive(hdr) - runtime.KeepAlive(pays) - if int(n) < virtioNetHdrLen { - return io.ErrShortWrite - } - return nil - } - if errno == unix.EAGAIN { - runtime.KeepAlive(hdr) - runtime.KeepAlive(pays) - if err := r.blockOnWrite(); err != nil { - return err - } - continue - } - if errno == unix.EINTR { - continue - } - runtime.KeepAlive(hdr) - runtime.KeepAlive(pays) - return errno - } -} - -func (r *tunFile) wakeForShutdown() error { - var buf [8]byte - binary.NativeEndian.PutUint64(buf[:], 1) - _, err := unix.Write(int(r.readPoll[1].Fd), buf[:]) - return err -} - -func (r *tunFile) Close() error { - if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem - return nil - } - r.closed = true - if r.lastOne { - _ = unix.Close(r.shutdownFd) - } - return unix.Close(r.fd) -} - type tun struct { - *tunFile - readers []*tunFile + readers tio.Container closeLock sync.Mutex Device string vpnNetworks []netip.Prefix @@ -484,6 +34,7 @@ type tun struct { TXQueueLen int deviceIndex int ioctlFd uintptr + vnetHdr bool Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -622,15 +173,28 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetworks []netip.Prefix) (*tun, error) { - tfd, err := newTunFd(fd, vnetHdr) + var container tio.Container + var err error + if vnetHdr { + container, err = tio.NewGSOContainer() + } else { + container, err = tio.NewPollContainer() + } + if err != nil { _ = unix.Close(fd) return nil, err } + err = container.Add(fd) + if err != nil { + _ = unix.Close(fd) + return nil, err + } + t := &tun{ - tunFile: tfd, - readers: []*tunFile{tfd}, + readers: container, closeLock: sync.Mutex{}, + vnetHdr: vnetHdr, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), @@ -732,13 +296,13 @@ func (t *tun) SupportsMultiqueue() bool { return true } -func (t *tun) NewMultiQueueReader() (Queue, error) { +func (t *tun) NewMultiQueueReader() error { t.closeLock.Lock() defer t.closeLock.Unlock() fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { - return nil, err + return err } flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) @@ -747,25 +311,23 @@ func (t *tun) NewMultiQueueReader() (Queue, error) { } if _, err = tunSetIff(fd, t.Device, flags); err != nil { _ = unix.Close(fd) - return nil, err + return err } if t.vnetHdr { if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil { _ = unix.Close(fd) - return nil, fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err) + return fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err) } } - out, err := t.tunFile.newFriend(fd) + err = t.readers.Add(fd) if err != nil { _ = unix.Close(fd) - return nil, err + return err } - t.readers = append(t.readers, out) - - return out, nil + return nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -1195,6 +757,10 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { t.routeTree.Store(newTree) } +func (t *tun) Readers() []tio.Queue { + return t.readers.Queues() +} + func (t *tun) Close() error { t.closeLock.Lock() defer t.closeLock.Unlock() @@ -1204,32 +770,10 @@ func (t *tun) Close() error { t.routeChan = nil } - // Signal all readers blocked in poll to wake up and exit - _ = t.tunFile.wakeForShutdown() - if t.ioctlFd > 0 { _ = unix.Close(int(t.ioctlFd)) t.ioctlFd = 0 } - for i := range t.readers { - if i == 0 { - continue //we want to close the zeroth reader last - } - err := t.readers[i].Close() - if err != nil { - t.l.WithField("reader", i).WithError(err).Error("error closing tun reader") - } else { - t.l.WithField("reader", i).Info("closed tun reader") - } - } - - //this is t.readers[0] too - err := t.tunFile.Close() - if err != nil { - t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader") - } else { - t.l.WithField("reader", 0).Info("closed tun reader") - } - return err + return t.readers.Close() } diff --git a/overlay/tun_linux_test.go b/overlay/tun_linux_test.go deleted file mode 100644 index 1c1842da..00000000 --- a/overlay/tun_linux_test.go +++ /dev/null @@ -1,34 +0,0 @@ -//go:build !e2e_testing -// +build !e2e_testing - -package overlay - -import "testing" - -var runAdvMSSTests = []struct { - name string - tun *tun - r Route - expected int -}{ - // Standard case, default MTU is the device max MTU - {"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, - {"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, - {"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, - - // Case where we have a route MTU set higher than the default - {"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, - {"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, - {"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, -} - -func TestTunAdvMSS(t *testing.T) { - for _, tt := range runAdvMSSTests { - t.Run(tt.name, func(t *testing.T) { - o := tt.tun.advMSS(tt.r) - if o != tt.expected { - t.Errorf("got %d, want %d", o, tt.expected) - } - }) - } -} diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 995a9a9f..c07e2a8e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -412,7 +413,7 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (Queue, error) { +func (t *tun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index aab29bb5..f7f06421 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" @@ -332,7 +333,7 @@ func (t *tun) SupportsMultiqueue() bool { return false } -func (t *tun) NewMultiQueueReader() (Queue, error) { +func (t *tun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 684d1ce1..6a612fce 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -13,6 +13,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -142,6 +143,6 @@ func (t *TestTun) SupportsMultiqueue() bool { return false } -func (t *TestTun) NewMultiQueueReader() (Queue, error) { +func (t *TestTun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index b02f33d5..2268e9b0 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" @@ -255,7 +256,7 @@ func (t *winTun) SupportsMultiqueue() bool { return false } -func (t *winTun) NewMultiQueueReader() (Queue, error) { +func (t *winTun) NewMultiQueueReader() (tio.Queue, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") } diff --git a/overlay/user.go b/overlay/user.go index 77c2d025..d8b53cf0 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) @@ -28,6 +29,7 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { type UserDevice struct { vpnNetworks []netip.Prefix + numReaders int outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -65,8 +67,17 @@ func (d *UserDevice) SupportsMultiqueue() bool { return true } -func (d *UserDevice) NewMultiQueueReader() (Queue, error) { - return d, nil +func (d *UserDevice) NewMultiQueueReader() error { + d.numReaders++ + return nil +} + +func (d *UserDevice) Readers() []tio.Queue { + out := make([]tio.Queue, d.numReaders) + for i := range d.numReaders { + out[i] = d + } + return out } func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {