This commit is contained in:
JackDoan
2026-04-17 11:39:46 -05:00
parent c05fa793a6
commit f8f63c470a
7 changed files with 310 additions and 0 deletions

70
batch.go Normal file
View File

@@ -0,0 +1,70 @@
package nebula
import "net/netip"
// sendBatchCap is the maximum number of encrypted packets a batch accumulates
// before a flush is forced.
//
// Sizing rationale: a TSO superpacket segments into at most ~45 packets on
// reasonable MTUs, so 128 slots leave headroom without bloating the backing
// allocation.
//
// NOTE(review): presumably this is the batchCap handed to newSendBatch; the
// call site is not visible here — confirm.
const sendBatchCap = 128
// sendBatch is a fixed-capacity staging area for encrypted UDP packets that
// are flushed together in a single sendmmsg call. Every listenIn goroutine
// owns its own sendBatch, so the type performs no locking.
//
// Packet bytes live in one contiguous backing buffer that is carved into
// batchCap slots of slotCap bytes each. bufs and dsts are parallel slices
// describing the slots committed so far: bufs[i] is the payload destined for
// dsts[i].
type sendBatch struct {
	// Slot storage and its geometry.
	backing  []byte // batchCap*slotCap bytes, allocated once and reused
	slotCap  int    // size of one slot in bytes
	batchCap int    // number of slots in backing
	nextSlot int    // index of the slot that Next hands out / Commit records

	// Committed packets, in commit order.
	bufs [][]byte         // sub-slices of backing, one per committed slot
	dsts []netip.AddrPort // destination for the packet at the same index in bufs
}
// newSendBatch returns an empty batch with room for batchCap packets of up to
// slotCap bytes each. All storage (the bufs/dsts bookkeeping slices and the
// batchCap*slotCap byte backing buffer) is allocated here, up front.
func newSendBatch(batchCap, slotCap int) *sendBatch {
	b := new(sendBatch)
	b.bufs = make([][]byte, 0, batchCap)
	b.dsts = make([]netip.AddrPort, 0, batchCap)
	b.backing = make([]byte, batchCap*slotCap)
	b.slotCap = slotCap
	b.batchCap = batchCap
	return b
}
// Next hands out the next unused slot as a zero-length slice whose capacity
// is exactly slotCap, backed by the batch's own storage. The caller appends
// the packet into it and then calls Commit with the final length and the
// destination. Calling Next does not advance the batch; only Commit does.
//
// Next returns nil once every slot has been committed (the batch is full).
func (b *sendBatch) Next() []byte {
	if b.nextSlot < b.batchCap {
		lo := b.slotCap * b.nextSlot
		hi := lo + b.slotCap
		// Three-index slice: len 0, cap limited to this slot.
		return b.backing[lo:lo:hi]
	}
	return nil
}
// Commit records the slot most recently returned by Next as a packet of n
// bytes destined for dst, and advances the batch to the next slot.
//
// The recorded slice is capped at the slot boundary (a three-index slice,
// mirroring Next), so an append on a committed buffer can never spill into
// the neighbouring slot.
//
// Commit must follow a non-nil Next, and n must satisfy 0 <= n <= slotCap.
// Violating either condition is a programming error: the slice expression
// below panics with a bounds error instead of silently recording a packet
// that overlaps another slot.
func (b *sendBatch) Commit(n int, dst netip.AddrPort) {
	start := b.nextSlot * b.slotCap
	b.bufs = append(b.bufs, b.backing[start:start+n:start+b.slotCap])
	b.dsts = append(b.dsts, dst)
	b.nextSlot++
}
// Reset empties the batch so its slots can be handed out again from the
// start. The bookkeeping slices are truncated to length zero and the backing
// buffer is kept as is (it is not zeroed), so no memory is released or
// reallocated.
func (b *sendBatch) Reset() {
	b.nextSlot = 0
	b.dsts = b.dsts[:0]
	b.bufs = b.bufs[:0]
}
// Len reports how many packets have been committed since the batch was
// created or last Reset.
func (b *sendBatch) Len() int {
	committed := len(b.bufs)
	return committed
}
// Cap reports the total number of slots in the batch, i.e. the batchCap it
// was constructed with.
func (b *sendBatch) Cap() int {
	slots := b.batchCap
	return slots
}

69
batch_test.go Normal file
View File

@@ -0,0 +1,69 @@
package nebula
import (
"net/netip"
"testing"
)
// TestSendBatchBookkeeping walks a 4-slot batch through its whole life cycle:
// fresh state, filling every slot via Next/Commit, the full condition, the
// recorded buffers and destinations, and reuse after Reset.
func TestSendBatchBookkeeping(t *testing.T) {
	const (
		slots   = 4
		slotLen = 32
	)

	b := newSendBatch(slots, slotLen)
	if gotLen, gotCap := b.Len(), b.Cap(); gotLen != 0 || gotCap != slots {
		t.Fatalf("fresh batch: len=%d cap=%d", gotLen, gotCap)
	}

	dst := netip.MustParseAddrPort("10.0.0.1:4242")
	for i := 0; i < slots; i++ {
		slot := b.Next()
		switch {
		case slot == nil:
			t.Fatalf("slot %d: Next returned nil before cap", i)
		case cap(slot) != slotLen || len(slot) != 0:
			t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot))
		}
		// Write three marker bytes derived from the slot index.
		marker := []byte{byte(i), byte(i + 1), byte(i + 2)}
		slot = append(slot, marker...)
		b.Commit(len(slot), dst)
	}

	// Every slot is committed, so the batch must report itself full.
	if extra := b.Next(); extra != nil {
		t.Fatalf("Next should return nil when full")
	}
	if got := b.Len(); got != slots {
		t.Fatalf("Len=%d want 4", got)
	}

	for i := range b.bufs {
		buf := b.bufs[i]
		if len(buf) != 3 || buf[0] != byte(i) {
			t.Errorf("buf %d: %x", i, buf)
		}
		if got := b.dsts[i]; got != dst {
			t.Errorf("dst %d: got %v want %v", i, got, dst)
		}
	}

	// After Reset the batch is empty and Next hands out slots again.
	b.Reset()
	if got := b.Len(); got != 0 {
		t.Fatalf("after Reset Len=%d want 0", got)
	}
	reused := b.Next()
	if reused == nil || cap(reused) != slotLen {
		t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", reused == nil, cap(reused))
	}
}
// TestSendBatchSlotsDoNotOverlap fills every slot of a small batch with a
// distinct two-byte sentinel and then checks that no slot's bytes were
// overwritten by a later one.
func TestSendBatchSlotsDoNotOverlap(t *testing.T) {
	const slots = 3

	// sentinel returns the two bytes expected in slot i.
	sentinel := func(i int) (byte, byte) {
		return byte(0xA0 + i), byte(0xB0 + i)
	}

	b := newSendBatch(slots, 8)
	dst := netip.MustParseAddrPort("10.0.0.1:80")

	for i := 0; i < slots; i++ {
		first, second := sentinel(i)
		pkt := append(b.Next(), first, second)
		b.Commit(len(pkt), dst)
	}

	for i := range b.bufs {
		got := b.bufs[i]
		first, second := sentinel(i)
		if got[0] != first || got[1] != second {
			t.Errorf("slot %d corrupted: %x", i, got)
		}
	}
}

View File

@@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
}
}
// WriteBatch sends bufs[i] to addrs[i] for each i, in order, using one WriteTo
// call per packet. It stops at the first failure and returns that error;
// packets after the failing one are not attempted.
//
// addrs is indexed without a length check, so it must hold at least len(bufs)
// entries or the call panics.
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
	for i := 0; i < len(bufs); i++ {
		err := u.WriteTo(bufs[i], addrs[i])
		if err != nil {
			return err
		}
	}
	return nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()

View File

@@ -24,6 +24,13 @@ type StdConn struct {
isV4 bool
l *logrus.Logger
batch int
// sendmmsg scratch. Each queue has its own StdConn, so no locking is
// needed. Sized to MaxWriteBatch at construction; WriteBatch chunks
// larger inputs.
writeMsgs []rawMessage
writeIovs []iovec
writeNames [][]byte
}
func setReusePort(network, address string, c syscall.RawConn) error {
@@ -70,6 +77,8 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
}
out.isV4 = af == unix.AF_INET
out.prepareWriteMessages(MaxWriteBatch)
return out, nil
}
@@ -235,6 +244,121 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return err
}
// WriteBatch sends bufs[i] to addrs[i] for every i using sendmmsg(2) and the
// scratch preallocated on StdConn (writeMsgs, writeIovs, writeNames).
//
// Inputs longer than the scratch are split into chunks, one sendmmsg call per
// chunk. When sendmmsg reports an error without having sent anything, the
// whole chunk is retried packet by packet through WriteTo so the caller still
// gets best-effort delivery. When sendmmsg sends only part of a chunk, the
// unsent tail is retried on the next loop iteration.
//
// Errors:
//   - len(bufs) != len(addrs);
//   - a destination that writeSockaddr cannot encode for this socket (the
//     chunk being prepared is not sent in that case);
//   - the first WriteTo failure on a fallback path;
//   - sendmmsg reporting success while sending zero packets.
//
// If no scratch has been prepared (len(u.writeMsgs) == 0), every packet is
// sent with WriteTo instead of indexing into the empty scratch.
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
	if len(bufs) != len(addrs) {
		return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs))
	}

	scratch := len(u.writeMsgs)
	if scratch == 0 {
		// Without scratch the chunk size would be 0 and sendmmsg would index
		// writeMsgs[0] and panic; degrade to one syscall per packet.
		for i, b := range bufs {
			if err := u.WriteTo(b, addrs[i]); err != nil {
				return err
			}
		}
		return nil
	}

	i := 0
	for i < len(bufs) {
		chunk := len(bufs) - i
		if chunk > scratch {
			chunk = scratch
		}

		// Point scratch slot k at packet i+k and encode its destination.
		for k := 0; k < chunk; k++ {
			b := bufs[i+k]
			if len(b) == 0 {
				// An empty datagram is legal: describe it with a nil base and
				// zero length rather than taking &b[0], which would panic.
				u.writeIovs[k].Base = nil
				setIovLen(&u.writeIovs[k], 0)
			} else {
				u.writeIovs[k].Base = &b[0]
				setIovLen(&u.writeIovs[k], len(b))
			}

			nlen, err := writeSockaddr(u.writeNames[k], addrs[i+k], u.isV4)
			if err != nil {
				return err
			}
			u.writeMsgs[k].Hdr.Namelen = uint32(nlen)
		}

		sent, serr := u.sendmmsg(chunk)
		if serr != nil && sent <= 0 {
			// Nothing went out; fall back to WriteTo for this chunk.
			for k := 0; k < chunk; k++ {
				if err := u.WriteTo(bufs[i+k], addrs[i+k]); err != nil {
					return err
				}
			}
			i += chunk
			continue
		}

		// Either success, or an error alongside a positive count; in both
		// cases the first `sent` packets are treated as delivered and the
		// remainder is retried on the next iteration.
		if sent == 0 {
			return fmt.Errorf("sendmmsg made no progress")
		}
		i += sent
	}
	return nil
}
// sendmmsg issues one sendmmsg(2) syscall over the first n entries of
// u.writeMsgs, which the caller must already have filled in. It returns the
// count reported by the syscall together with any error.
//
// The syscall runs inside u.rawConn.Write. On EAGAIN/EWOULDBLOCK the callback
// returns false, which asks rawConn.Write to invoke it again later; any other
// outcome ends the call. A non-zero errno is wrapped in a *net.OpError with
// Op "sendmmsg". An error from rawConn.Write itself takes precedence over the
// errno error.
//
// On error the returned count is whatever the syscall left in r1 (or zero if
// the callback never completed) — callers should treat a non-positive count
// as "nothing sent".
func (u *StdConn) sendmmsg(n int) (int, error) {
	var (
		count   int
		sendErr error
	)

	writeErr := u.rawConn.Write(func(fd uintptr) bool {
		// The unsafe.Pointer -> uintptr conversion must stay inside the call
		// expression so the pointed-to memory is kept alive for the syscall.
		r1, _, errno := unix.Syscall6(
			unix.SYS_SENDMMSG,
			fd,
			uintptr(unsafe.Pointer(&u.writeMsgs[0])),
			uintptr(n),
			0,
			0,
			0,
		)
		if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK {
			// Socket not writable yet; let rawConn.Write retry.
			return false
		}

		count = int(r1)
		if errno != 0 {
			sendErr = &net.OpError{Op: "sendmmsg", Err: errno}
		}
		return true
	})

	if writeErr == nil {
		return count, sendErr
	}
	return count, writeErr
}
// writeSockaddr encodes addr into buf as a raw kernel sockaddr and returns the
// number of bytes used.
//
// When isV4 is true the socket is AF_INET: addr must be an IPv4 (or
// IPv4-mapped IPv6) address, otherwise ErrInvalidIPv6RemoteForSocket is
// returned, and buf receives a 16-byte sockaddr_in. When isV4 is false buf
// receives a 28-byte sockaddr_in6; an IPv4 addr is written in its 16-byte
// (IPv4-mapped) form. flowinfo and scope_id are always written as zero, so any
// zone on addr is not conveyed.
//
// buf must be at least as long as the structure being written
// (unix.SizeofSockaddrInet6 covers both cases); a shorter buf yields an error
// rather than a panic.
func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) {
	ap := addr.Addr().Unmap()

	if isV4 {
		if !ap.Is4() {
			return 0, ErrInvalidIPv6RemoteForSocket
		}
		if len(buf) < unix.SizeofSockaddrInet4 {
			return 0, fmt.Errorf("writeSockaddr: buffer too small for sockaddr_in: %d < %d", len(buf), unix.SizeofSockaddrInet4)
		}
		sa := buf[:unix.SizeofSockaddrInet4]

		// struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) }
		// sa_family is host endian.
		binary.NativeEndian.PutUint16(sa[0:2], unix.AF_INET)
		binary.BigEndian.PutUint16(sa[2:4], addr.Port())
		ip4 := ap.As4()
		copy(sa[4:8], ip4[:])
		// The scratch buffer is reused across calls, so the padding must be
		// re-zeroed every time.
		var zero [8]byte
		copy(sa[8:16], zero[:])
		return unix.SizeofSockaddrInet4, nil
	}

	if len(buf) < unix.SizeofSockaddrInet6 {
		return 0, fmt.Errorf("writeSockaddr: buffer too small for sockaddr_in6: %d < %d", len(buf), unix.SizeofSockaddrInet6)
	}
	sa := buf[:unix.SizeofSockaddrInet6]

	// struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) }
	binary.NativeEndian.PutUint16(sa[0:2], unix.AF_INET6)
	binary.BigEndian.PutUint16(sa[2:4], addr.Port())
	binary.NativeEndian.PutUint32(sa[4:8], 0) // flowinfo
	ip6 := addr.Addr().As16()
	copy(sa[8:24], ip6[:])
	binary.NativeEndian.PutUint32(sa[24:28], 0) // scope_id
	return unix.SizeofSockaddrInet6, nil
}
func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {

View File

@@ -52,3 +52,23 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names
}
// prepareWriteMessages allocates the sendmmsg scratch for n messages and
// stores it on u: one rawMessage, one iovec and one sockaddr buffer of
// unix.SizeofSockaddrInet6 bytes per slot.
//
// Each message header is pre-wired to its own slot: Hdr.Iov points at the
// matching iovec (with Iovlen fixed at 1) and Hdr.Name points at the first
// byte of the matching sockaddr buffer. Before every sendmmsg the caller only
// has to fill in the iovec Base/Len, the sockaddr bytes and Hdr.Namelen.
func (u *StdConn) prepareWriteMessages(n int) {
	msgs := make([]rawMessage, n)
	iovs := make([]iovec, n)
	names := make([][]byte, n)

	for i := 0; i < n; i++ {
		name := make([]byte, unix.SizeofSockaddrInet6)
		names[i] = name

		hdr := &msgs[i].Hdr
		hdr.Iov = &iovs[i]
		hdr.Iovlen = 1
		hdr.Name = &name[0]
	}

	u.writeMsgs, u.writeIovs, u.writeNames = msgs, iovs, names
}
// setIovLen stores n in v.Len, converting it to the uint32 used by the iovec
// definition in this file.
func setIovLen(v *iovec, n int) {
	length := uint32(n)
	v.Len = length
}

View File

@@ -316,6 +316,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
// WriteBatch sends every packet in bufs to the address at the same index in
// addrs by calling WriteTo once per packet, in order. The first WriteTo error
// aborts the batch and is returned; later packets are not sent.
//
// addrs is indexed without a length check, so it must hold at least len(bufs)
// entries or the call panics.
func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
	for idx := range bufs {
		if sendErr := u.WriteTo(bufs[idx], addrs[idx]); sendErr != nil {
			return sendErr
		}
	}
	return nil
}
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock)
if err != nil {

View File

@@ -107,6 +107,15 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil
}
// WriteBatch forwards each (bufs[i], addrs[i]) pair to WriteTo in index order
// and returns the first error encountered, leaving the remaining packets
// unsent. It returns nil when every WriteTo succeeds.
//
// addrs is indexed without a length check, so it must hold at least len(bufs)
// entries or the call panics.
func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
	var err error
	for n, pkt := range bufs {
		if err = u.WriteTo(pkt, addrs[n]); err != nil {
			break
		}
	}
	return err
}
func (u *TesterConn) ListenOut(r EncReader) error {
for {
p, ok := <-u.RxPackets