From f8f63c470a2f9d7429f1f0a5cd2456b99469b73e Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 17 Apr 2026 11:39:46 -0500 Subject: [PATCH] checkpt --- batch.go | 70 +++++++++++++++++++++++ batch_test.go | 69 +++++++++++++++++++++++ udp/udp_darwin.go | 9 +++ udp/udp_linux.go | 124 +++++++++++++++++++++++++++++++++++++++++ udp/udp_linux_32.go | 20 +++++++ udp/udp_rio_windows.go | 9 +++ udp/udp_tester.go | 9 +++ 7 files changed, 310 insertions(+) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 00000000..02d86bc7 --- /dev/null +++ b/batch.go @@ -0,0 +1,70 @@ +package nebula + +import "net/netip" + +// sendBatchCap is the maximum number of encrypted packets accumulated before a +// flush is forced. TSO superpackets segment to at most ~45 packets on +// reasonable MTUs, so 128 leaves headroom without bloating the backing +// allocation. +const sendBatchCap = 128 + +// sendBatch accumulates encrypted UDP packets for a single sendmmsg flush. +// One sendBatch is owned by each listenIn goroutine; no locking is needed. +// The backing storage holds up to batchCap packets of slotCap bytes each; +// bufs and dsts are parallel slices of committed slots. +type sendBatch struct { + bufs [][]byte + dsts []netip.AddrPort + backing []byte + slotCap int + batchCap int + nextSlot int +} + +func newSendBatch(batchCap, slotCap int) *sendBatch { + return &sendBatch{ + bufs: make([][]byte, 0, batchCap), + dsts: make([]netip.AddrPort, 0, batchCap), + backing: make([]byte, batchCap*slotCap), + slotCap: slotCap, + batchCap: batchCap, + } +} + +// Next returns a zero-length slice with slotCap capacity over the next unused +// slot's backing bytes. The caller writes into the returned slice and then +// calls Commit with the final length and destination. Next returns nil when +// the batch is full. +func (b *sendBatch) Next() []byte { + if b.nextSlot >= b.batchCap { + return nil + } + start := b.nextSlot * b.slotCap + return b.backing[start : start : start+b.slotCap] +} + +// Commit records the slot just returned by Next as a packet of length n +// destined for dst. +func (b *sendBatch) Commit(n int, dst netip.AddrPort) { + start := b.nextSlot * b.slotCap + b.bufs = append(b.bufs, b.backing[start:start+n]) + b.dsts = append(b.dsts, dst) + b.nextSlot++ +} + +// Reset clears committed slots; backing storage is retained for reuse. +func (b *sendBatch) Reset() { + b.bufs = b.bufs[:0] + b.dsts = b.dsts[:0] + b.nextSlot = 0 +} + +// Len returns the number of committed packets. +func (b *sendBatch) Len() int { + return len(b.bufs) +} + +// Cap returns the maximum number of slots in the batch. +func (b *sendBatch) Cap() int { + return b.batchCap +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 00000000..ee72c63e --- /dev/null +++ b/batch_test.go @@ -0,0 +1,69 @@ +package nebula + +import ( + "net/netip" + "testing" +) + +func TestSendBatchBookkeeping(t *testing.T) { + b := newSendBatch(4, 32) + if b.Len() != 0 || b.Cap() != 4 { + t.Fatalf("fresh batch: len=%d cap=%d", b.Len(), b.Cap()) + } + + ap := netip.MustParseAddrPort("10.0.0.1:4242") + for i := 0; i < 4; i++ { + slot := b.Next() + if slot == nil { + t.Fatalf("slot %d: Next returned nil before cap", i) + } + if cap(slot) != 32 || len(slot) != 0 { + t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot)) + } + // Write a marker byte. + slot = append(slot, byte(i), byte(i+1), byte(i+2)) + b.Commit(len(slot), ap) + } + if b.Next() != nil { + t.Fatalf("Next should return nil when full") + } + if b.Len() != 4 { + t.Fatalf("Len=%d want 4", b.Len()) + } + for i, buf := range b.bufs { + if len(buf) != 3 || buf[0] != byte(i) { + t.Errorf("buf %d: %x", i, buf) + } + if b.dsts[i] != ap { + t.Errorf("dst %d: got %v want %v", i, b.dsts[i], ap) + } + } + + // Reset returns empty and Next works again. + b.Reset() + if b.Len() != 0 { + t.Fatalf("after Reset Len=%d want 0", b.Len()) + } + slot := b.Next() + if slot == nil || cap(slot) != 32 { + t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", slot == nil, cap(slot)) + } +} + +func TestSendBatchSlotsDoNotOverlap(t *testing.T) { + b := newSendBatch(3, 8) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + // Fill three slots, each with its own sentinel byte. + for i := 0; i < 3; i++ { + s := b.Next() + s = append(s, byte(0xA0+i), byte(0xB0+i)) + b.Commit(len(s), ap) + } + + for i, buf := range b.bufs { + if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) { + t.Errorf("slot %d corrupted: %x", i, buf) + } + } +} diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 863c98f3..30f8f50e 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 21a34147..895fda7b 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -24,6 +24,13 @@ type StdConn struct { isV4 bool l *logrus.Logger batch int + + // sendmmsg scratch. Each queue has its own StdConn, so no locking is + // needed. Sized to MaxWriteBatch at construction; WriteBatch chunks + // larger inputs. + writeMsgs []rawMessage + writeIovs []iovec + writeNames [][]byte } func setReusePort(network, address string, c syscall.RawConn) error { @@ -70,6 +77,8 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in } out.isV4 = af == unix.AF_INET + out.prepareWriteMessages(MaxWriteBatch) + return out, nil } @@ -235,6 +244,121 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return err } +// WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on +// StdConn. Chunks larger than the scratch are processed in multiple syscalls. +// If sendmmsg returns a fatal error mid-chunk we fall back to single WriteTo +// calls for the remainder so the caller still gets best-effort delivery. +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + if len(bufs) != len(addrs) { + return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs)) + } + + i := 0 + for i < len(bufs) { + chunk := len(bufs) - i + if chunk > len(u.writeMsgs) { + chunk = len(u.writeMsgs) + } + + for k := 0; k < chunk; k++ { + b := bufs[i+k] + if len(b) == 0 { + // sendmmsg with an empty iovec is legal but pointless; fall + // through after filling the slot so Base is still valid. + u.writeIovs[k].Base = nil + setIovLen(&u.writeIovs[k], 0) + } else { + u.writeIovs[k].Base = &b[0] + setIovLen(&u.writeIovs[k], len(b)) + } + nlen, err := writeSockaddr(u.writeNames[k], addrs[i+k], u.isV4) + if err != nil { + return err + } + u.writeMsgs[k].Hdr.Namelen = uint32(nlen) + } + + sent, serr := u.sendmmsg(chunk) + if serr != nil { + if sent <= 0 { + // nothing went out; fall back to WriteTo for this chunk. + for k := 0; k < chunk; k++ { + if err := u.WriteTo(bufs[i+k], addrs[i+k]); err != nil { + return err + } + } + i += chunk + continue + } + // partial: treat as success for the sent packets and retry the + // remainder on the next outer-loop iteration. + } + if sent == 0 { + return fmt.Errorf("sendmmsg made no progress") + } + i += sent + } + return nil +} + +func (u *StdConn) sendmmsg(n int) (int, error) { + var sent int + var sysErr error + err := u.rawConn.Write(func(fd uintptr) (done bool) { + r1, _, errno := unix.Syscall6( + unix.SYS_SENDMMSG, + fd, + uintptr(unsafe.Pointer(&u.writeMsgs[0])), + uintptr(n), + 0, + 0, + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + return false + } + sent = int(r1) + if errno != 0 { + sysErr = &net.OpError{Op: "sendmmsg", Err: errno} + } + return true + }) + if err != nil { + return sent, err + } + return sent, sysErr +} + +// writeSockaddr encodes addr into buf (which must be at least +// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is +// true and addr is not a v4 (or v4-in-v6) address, returns an error. +func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) { + ap := addr.Addr().Unmap() + if isV4 { + if !ap.Is4() { + return 0, ErrInvalidIPv6RemoteForSocket + } + // struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) } + // sa_family is host endian. + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + ip4 := ap.As4() + copy(buf[4:8], ip4[:]) + for j := 8; j < 16; j++ { + buf[j] = 0 + } + return unix.SizeofSockaddrInet4, nil + } + // struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) } + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + binary.NativeEndian.PutUint32(buf[4:8], 0) + ip6 := addr.Addr().As16() + copy(buf[8:24], ip6[:]) + binary.NativeEndian.PutUint32(buf[24:28], 0) + return unix.SizeofSockaddrInet6, nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..4f45b1a4 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -52,3 +52,23 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +// prepareWriteMessages allocates one Mmsghdr/iovec/sockaddr scratch per slot, +// wired up so each writeMsgs[i] already points at writeIovs[i] and +// writeNames[i]. Callers fill in the iovec Base/Len, the sockaddr bytes, and +// Namelen before each sendmmsg. +func (u *StdConn) prepareWriteMessages(n int) { + u.writeMsgs = make([]rawMessage, n) + u.writeIovs = make([]iovec, n) + u.writeNames = make([][]byte, n) + for i := range u.writeMsgs { + u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + u.writeMsgs[i].Hdr.Iov = &u.writeIovs[i] + u.writeMsgs[i].Hdr.Iovlen = 1 + u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0] + } +} + +func setIovLen(v *iovec, n int) { + v.Len = uint32(n) +} diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 607b978e..f21f5610 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -316,6 +316,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } +func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5db72555..69ca24a1 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -107,6 +107,15 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } +func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { + for i, b := range bufs { + if err := u.WriteTo(b, addrs[i]); err != nil { + return err + } + } + return nil +} + func (u *TesterConn) ListenOut(r EncReader) error { for { p, ok := <-u.RxPackets