diff --git a/batch_test.go b/batch_test.go
index ee72c63e..f33f9381 100644
--- a/batch_test.go
+++ b/batch_test.go
@@ -50,6 +50,86 @@ func TestSendBatchBookkeeping(t *testing.T) {
 	}
 }
 
+// TestBatchSegmentable pins which batches batchSegmentable accepts as a single
+// GSO superpacket and which segment size it reports for them.
+func TestBatchSegmentable(t *testing.T) {
+	ap := netip.MustParseAddrPort("10.0.0.1:4242")
+	other := netip.MustParseAddrPort("10.0.0.2:4242")
+
+	// mk commits one packet per addrs[i], filled with sizes[i] bytes.
+	mk := func(addrs []netip.AddrPort, sizes []int) *sendBatch {
+		b := newSendBatch(len(addrs), 64)
+		for i, a := range addrs {
+			s := b.Next()
+			for j := 0; j < sizes[i]; j++ {
+				s = append(s, byte(j))
+			}
+			b.Commit(len(s), a)
+		}
+		return b
+	}
+
+	t.Run("uniform same dst", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 10, 10})
+		seg, ok := batchSegmentable(b)
+		if !ok || seg != 10 {
+			t.Fatalf("got seg=%d ok=%v", seg, ok)
+		}
+	})
+
+	t.Run("last segment short ok", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 10, 4})
+		seg, ok := batchSegmentable(b)
+		if !ok || seg != 10 {
+			t.Fatalf("got seg=%d ok=%v", seg, ok)
+		}
+	})
+
+	// flushBatch sends single-packet batches through this check too.
+	t.Run("single packet ok", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap}, []int{10})
+		seg, ok := batchSegmentable(b)
+		if !ok || seg != 10 {
+			t.Fatalf("got seg=%d ok=%v", seg, ok)
+		}
+	})
+
+	t.Run("mixed dst rejected", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap, other, ap}, []int{10, 10, 10})
+		if _, ok := batchSegmentable(b); ok {
+			t.Fatalf("expected rejection for mixed dst")
+		}
+	})
+
+	t.Run("mid-batch short rejected", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 4, 10})
+		if _, ok := batchSegmentable(b); ok {
+			t.Fatalf("expected rejection for short mid-batch")
+		}
+	})
+
+	t.Run("mid-batch longer rejected", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 11, 10})
+		if _, ok := batchSegmentable(b); ok {
+			t.Fatalf("expected rejection for longer mid-batch")
+		}
+	})
+
+	t.Run("last longer rejected", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap, ap, ap}, []int{10, 10, 11})
+		if _, ok := batchSegmentable(b); ok {
+			t.Fatalf("expected rejection for longer last segment")
+		}
+	})
+
+	t.Run("first zero rejected", func(t *testing.T) {
+		b := mk([]netip.AddrPort{ap, ap}, []int{0, 10})
+		if _, ok := batchSegmentable(b); ok {
+			t.Fatalf("expected rejection for zero first")
+		}
+	})
+}
+
 func TestSendBatchSlotsDoNotOverlap(t *testing.T) {
 	b := newSendBatch(3, 8)
 	ap := netip.MustParseAddrPort("10.0.0.1:80")
diff --git a/interface.go b/interface.go
index 824c21b5..f4a66c6f 100644
--- a/interface.go
+++ b/interface.go
@@ -355,17 +355,52 @@ func (f *Interface) listenIn(reader overlay.Queue, i int) {
 }
 
+// flushBatch writes one assembled batch through writer q. When that writer
+// supports GSO and batchSegmentable accepts the batch, the whole batch goes to
+// WriteSegmented; otherwise to WriteBatch. Write errors are logged, not returned.
 func (f *Interface) flushBatch(batch *sendBatch, q int) {
-	if len(batch.bufs) == 1 {
-		if err := f.writers[q].WriteTo(batch.bufs[0], batch.dsts[0]); err != nil {
-			f.l.WithError(err).WithField("writer", q).Error("Failed to write outgoing single-batch")
+	w := f.writers[q]
+	if w.SupportsGSO() {
+		if segSize, ok := batchSegmentable(batch); ok {
+			if err := w.WriteSegmented(batch.bufs, batch.dsts[0], segSize); err != nil {
+				f.l.WithError(err).WithField("writer", q).Error("Failed to write outgoing GSO batch")
+			}
+			return
 		}
-		return
 	}
-	if err := f.writers[q].WriteBatch(batch.bufs, batch.dsts); err != nil {
+	if err := w.WriteBatch(batch.bufs, batch.dsts); err != nil {
 		f.l.WithError(err).WithField("writer", q).Error("Failed to write outgoing batch")
 	}
 }
 
+// batchSegmentable reports whether a batch can be emitted as a single UDP GSO
+// superpacket: all packets go to the same destination, and every packet has
+// the same non-zero length except the last, which may be shorter but not
+// empty. Returns the segment size on success. A single-packet batch
+// qualifies too: flushBatch no longer handles that case separately.
+func batchSegmentable(b *sendBatch) (int, bool) {
+	if len(b.bufs) == 0 || len(b.bufs[0]) == 0 { // empty batch or empty first packet: nothing to segment
+		return 0, false
+	}
+	segSize := len(b.bufs[0])
+	dst := b.dsts[0]
+	last := len(b.bufs) - 1
+	for i := 1; i <= last; i++ {
+		if b.dsts[i] != dst {
+			return 0, false
+		}
+		if i < last {
+			if len(b.bufs[i]) != segSize {
+				return 0, false
+			}
+		} else {
+			if len(b.bufs[i]) == 0 || len(b.bufs[i]) > segSize {
+				return 0, false
+			}
+		}
+	}
+	return segSize, true
+}
+
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
diff --git a/udp/conn.go b/udp/conn.go
index 652ff79d..f66deba9 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -30,6 +30,16 @@ type Conn interface {
 	// WriteTo loop. Returns on the first error; callers may observe a
 	// partial send if some packets went out before the error.
 	WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error
+	// WriteSegmented sends bufs as a single UDP GSO sendmsg when the kernel
+	// supports it: all bufs go to the same addr, each must be exactly segSize
+	// bytes except the last which may be shorter. The kernel emits one
+	// datagram per buf on the wire. Backends / kernels without GSO support
+	// fall back to a per-packet WriteTo loop. Returns on the first error.
+	WriteSegmented(bufs [][]byte, addr netip.AddrPort, segSize int) error
+	// SupportsGSO reports whether WriteSegmented takes the single-syscall
+	// GSO path. Callers use this to decide at batch-assembly time whether
+	// the uniform-size / same-dst check is worth running.
+	SupportsGSO() bool
 	ReloadConfig(c *config.C)
 	SupportsMultipleReaders() bool
 	Close() error
@@ -55,6 +65,12 @@ func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error {
 	return nil
 }
+func (NoopConn) WriteSegmented(_ [][]byte, _ netip.AddrPort, _ int) error {
+	return nil
+}
+func (NoopConn) SupportsGSO() bool {
+	return false
+}
 func (NoopConn) ReloadConfig(_ *config.C) {
 	return
 }
diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go
index 30f8f50e..00c88203 100644
--- a/udp/udp_darwin.go
+++ b/udp/udp_darwin.go
@@ -149,6 +149,17 @@ func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
 	return nil
 }
 
+func (u *StdConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error {
+	for _, b := range bufs {
+		if err := u.WriteTo(b, addr); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (u *StdConn) SupportsGSO() bool { return false }
+
 func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 
diff --git a/udp/udp_generic.go b/udp/udp_generic.go
index 9e39436f..f29bbc1f 100644
--- a/udp/udp_generic.go
+++ b/udp/udp_generic.go
@@ -51,6 +51,17 @@ func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
 	return nil
 }
 
+func (u *GenericConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error {
+	for _, b := range bufs {
+		if _, err := u.UDPConn.WriteToUDPAddrPort(b, addr); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (u *GenericConn) SupportsGSO() bool { return false }
+
 func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 
diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index 09c85033..ca9988c9 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -38,6 +38,18 @@ type StdConn struct {
 	writeSent  int
 	writeErrno syscall.Errno
 	writeFunc  func(fd uintptr) bool
+
+	// UDP GSO (sendmsg with UDP_SEGMENT cmsg) support. gsoSupported is
+	// probed once at socket creation. When true, WriteSegmented takes a
+	// single-syscall GSO path; otherwise it falls back to a WriteTo loop.
+	gsoSupported bool
+	gsoMsg       msghdr
+	gsoIovs      []iovec
+	gsoName      []byte // SizeofSockaddrInet6
+	gsoCmsg      []byte // CmsgSpace(2)
+	gsoSent      int
+	gsoErrno     syscall.Errno
+	gsoFunc      func(fd uintptr) bool
 }
 
 func setReusePort(network, address string, c syscall.RawConn) error {
@@ -87,9 +99,58 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
 	out.prepareWriteMessages(MaxWriteBatch)
 	out.writeFunc = out.sendmmsgRawWrite
 
+	out.prepareGSO()
+
 	return out, nil
 }
 
+// maxGSOSegments caps the per-sendmsg GSO fan-out. Linux kernels have
+// historically capped UDP_MAX_SEGMENTS at 64; newer kernels raise it to 128
+// but we stay conservative so the same code works everywhere.
+const maxGSOSegments = 64
+
+// maxGSOBytes bounds the total payload per sendmsg() when UDP_SEGMENT is
+// set. The kernel builds the superpacket as one datagram before splitting
+// it, so payload plus the UDP (8) and IPv4 (20) headers has to fit in the
+// 16-bit IP length; past that sendmsg fails with EMSGSIZE. IPv6 could take
+// 65527, but one IPv4-safe bound keeps the chunking identical for both.
+const maxGSOBytes = 65535 - 20 - 8
+
+// prepareGSO probes UDP_SEGMENT support and, on success, sets up the
+// reusable sendmsg scratch (iovecs, sockaddr, cmsg) plus the preallocated
+// raw-write closure used to avoid heap allocations on the hot path.
+func (u *StdConn) prepareGSO() {
+	var probeErr error
+	if err := u.rawConn.Control(func(fd uintptr) {
+		probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT, 0)
+	}); err != nil {
+		return
+	}
+	if probeErr != nil {
+		return
+	}
+	u.gsoSupported = true
+	u.gsoIovs = make([]iovec, maxGSOSegments)
+	u.gsoName = make([]byte, unix.SizeofSockaddrInet6)
+	u.gsoCmsg = make([]byte, unix.CmsgSpace(2))
+
+	// Wire up the static pieces of gsoMsg. Iovlen / Controllen / Namelen /
+	// cmsg contents get refreshed per call; Iov, Name, Control pointers are
+	// fixed because the scratch slices never move.
+	u.gsoMsg.Iov = &u.gsoIovs[0]
+	u.gsoMsg.Name = &u.gsoName[0]
+	u.gsoMsg.Control = &u.gsoCmsg[0]
+
+	// Prepopulate the cmsg header. Len/Level/Type are constant for our use;
+	// only the 2-byte gso_size payload changes per call.
+	cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(&u.gsoCmsg[0]))
+	cmsghdr.Level = unix.SOL_UDP
+	cmsghdr.Type = unix.UDP_SEGMENT
+	setCmsgLen(cmsghdr, unix.CmsgLen(2))
+
+	u.gsoFunc = u.sendmsgRawWriteGSO
+}
+
 func (u *StdConn) SupportsMultipleReaders() bool {
 	return true
 }
@@ -331,6 +392,118 @@ func (u *StdConn) sendmmsgRawWrite(fd uintptr) bool {
 	return true
 }
 
+func (u *StdConn) SupportsGSO() bool {
+	return u.gsoSupported
+}
+
+// WriteSegmented sends bufs to addr as one or more UDP GSO superpackets. The
+// kernel emits one datagram per iovec on the wire; all iovecs except the
+// last must be exactly segSize bytes. Sockets without GSO, and calls whose
+// segSize is <= 0 or larger than maxGSOBytes, use a per-packet WriteTo loop
+// instead. Batches longer than the per-syscall limit are sent as several
+// sendmsg calls.
+func (u *StdConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, segSize int) error {
+	if len(bufs) == 0 {
+		return nil
+	}
+	// segSize <= 0 has to be turned away here: it would otherwise divide by
+	// zero in the chunk-size computation below.
+	if !u.gsoSupported || segSize <= 0 {
+		return u.writeToEach(bufs, addr)
+	}
+
+	nlen, err := writeSockaddr(u.gsoName, addr, u.isV4)
+	if err != nil {
+		return err
+	}
+	u.gsoMsg.Namelen = uint32(nlen)
+	setMsgControllen(&u.gsoMsg, unix.CmsgSpace(2))
+
+	// Cap the per-syscall fan-out by both segment count and total bytes; the
+	// kernel rejects an oversized superpacket with EMSGSIZE. A segSize larger
+	// than maxGSOBytes cannot use GSO at all and goes out per-packet (this
+	// also keeps segSize within the uint16 the cmsg carries).
+	segsByBytes := maxGSOBytes / segSize
+	if segsByBytes == 0 {
+		return u.writeToEach(bufs, addr)
+	}
+	maxChunk := maxGSOSegments
+	if segsByBytes < maxChunk {
+		maxChunk = segsByBytes
+	}
+
+	i := 0
+	for i < len(bufs) {
+		chunk := len(bufs) - i
+		if chunk > maxChunk {
+			chunk = maxChunk
+		}
+		for k := 0; k < chunk; k++ {
+			b := bufs[i+k]
+			if len(b) == 0 {
+				u.gsoIovs[k].Base = nil
+				setIovLen(&u.gsoIovs[k], 0)
+			} else {
+				u.gsoIovs[k].Base = &b[0]
+				setIovLen(&u.gsoIovs[k], len(b))
+			}
+		}
+		setMsgIovlen(&u.gsoMsg, chunk)
+		binary.NativeEndian.PutUint16(u.gsoCmsg[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
+
+		if serr := u.sendmsgGSO(); serr != nil {
+			// The GSO send failed: send this chunk packet by packet instead.
+			// The GSO error itself is not returned, and later chunks still
+			// try the GSO path first.
+			if werr := u.writeToEach(bufs[i:i+chunk], addr); werr != nil {
+				return werr
+			}
+		}
+		i += chunk
+	}
+	return nil
+}
+
+// writeToEach sends every buffer in bufs to addr with its own WriteTo call,
+// stopping at the first error. It is the non-GSO path of WriteSegmented.
+func (u *StdConn) writeToEach(bufs [][]byte, addr netip.AddrPort) error {
+	for _, b := range bufs {
+		if err := u.WriteTo(b, addr); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// sendmsgRawWriteGSO is the preallocated rawConn.Write callback for the GSO
+// path. Reads the prebuilt u.gsoMsg and writes u.gsoSent / u.gsoErrno.
+func (u *StdConn) sendmsgRawWriteGSO(fd uintptr) bool {
+	r1, _, errno := unix.Syscall(
+		unix.SYS_SENDMSG,
+		fd,
+		uintptr(unsafe.Pointer(&u.gsoMsg)),
+		0, // no sendmsg flags
+	)
+	if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK {
+		return false // socket not writable yet; presumably rawConn.Write waits and calls this again
+	}
+	u.gsoSent = int(r1)
+	u.gsoErrno = errno // zero on success; sendmsgGSO turns a non-zero value into an error
+	return true
+}
+
+func (u *StdConn) sendmsgGSO() error {
+	u.gsoSent = 0 // clear the results left by the previous call
+	u.gsoErrno = 0
+	if err := u.rawConn.Write(u.gsoFunc); err != nil { // gsoFunc is the callback stored by prepareGSO
+		return err
+	}
+	if u.gsoErrno != 0 {
+		return &net.OpError{Op: "sendmsg", Err: u.gsoErrno} // errno recorded by the callback
+	}
+	return nil
+}
+
 func (u *StdConn) sendmmsg(n int) (int, error) {
 	u.writeChunk = n
 	u.writeSent = 0
diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go
index 4f45b1a4..1fd469f3 100644
--- a/udp/udp_linux_32.go
+++ b/udp/udp_linux_32.go
@@ -72,3 +72,15 @@ func (u *StdConn) prepareWriteMessages(n int) {
 func setIovLen(v *iovec, n int) {
 	v.Len = uint32(n)
 }
+
+func setMsgIovlen(m *msghdr, n int) {
+	m.Iovlen = uint32(n) // iovec count, stored as uint32 in this build
+}
+
+func setMsgControllen(m *msghdr, n int) {
+	m.Controllen = uint32(n) // control buffer length, stored as uint32 in this build
+}
+
+func setCmsgLen(h *unix.Cmsghdr, n int) {
+	h.Len = uint32(n) // cmsg length, stored as uint32 in this build
+}
diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go
index 96c548e7..6f5008c0 100644
--- a/udp/udp_linux_64.go
+++ b/udp/udp_linux_64.go
@@ -75,3 +75,15 @@ func (u *StdConn) prepareWriteMessages(n int) {
 func setIovLen(v *iovec, n int) {
 	v.Len = uint64(n)
 }
+
+func setMsgIovlen(m *msghdr, n int) {
+	m.Iovlen = uint64(n) // iovec count, stored as uint64 in this build
+}
+
+func setMsgControllen(m *msghdr, n int) {
+	m.Controllen = uint64(n) // control buffer length, stored as uint64 in this build
+}
+
+func setCmsgLen(h *unix.Cmsghdr, n int) {
+	h.Len = uint64(n) // cmsg length, stored as uint64 in this build
+}
diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go
index f21f5610..fc15acbb 100644
--- a/udp/udp_rio_windows.go
+++ b/udp/udp_rio_windows.go
@@ -325,6 +325,17 @@ func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
 	return nil
 }
 
+func (u *RIOConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error {
+	for _, b := range bufs { // no GSO path here: one WriteTo per packet, segment size ignored
+		if err := u.WriteTo(b, addr); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (u *RIOConn) SupportsGSO() bool { return false }
+
 func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
 	sa, err := windows.Getsockname(u.sock)
 	if err != nil {
diff --git a/udp/udp_tester.go b/udp/udp_tester.go
index 69ca24a1..522f95f7 100644
--- a/udp/udp_tester.go
+++ b/udp/udp_tester.go
@@ -116,6 +116,17 @@ func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
 	return nil
 }
 
+func (u *TesterConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) error {
+	for _, b := range bufs { // no GSO path here: one WriteTo per packet, segment size ignored
+		if err := u.WriteTo(b, addr); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (u *TesterConn) SupportsGSO() bool { return false }
+
 func (u *TesterConn) ListenOut(r EncReader) error {
 	for {
 		p, ok := <-u.RxPackets