diff --git a/interface.go b/interface.go
index 2ef1c31..74e2c84 100644
--- a/interface.go
+++ b/interface.go
@@ -280,15 +280,20 @@ func (f *Interface) listenOut(i int) {
 	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
 
-	// Allocate plaintext buffer with virtio header headroom to avoid copies on TUN write
-	plaintext := make([]byte, virtioNetHdrLen+udp.MTU)
+	// Pre-allocate output buffers for batch processing
+	batchSize := li.BatchSize()
+	outs := make([][]byte, batchSize)
+	for idx := range outs {
+		// Allocate full buffer with virtio header space
+		outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
+	}
 
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12)
 
-	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
-		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:virtioNetHdrLen], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
+		f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
 	})
 }
 
diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index f40031c..3c98d72 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -102,6 +102,12 @@ func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
 	return len(b), nil
 }
 
+// WriteBatch forwards bufs and offset unchanged to the wrapped device's Write.
+func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
+	// Pass all buffers to WireGuard's batch write
+	return w.dev.Write(bufs, offset)
+}
+
 func (w *wgDeviceWrapper) Close() error {
 	return w.dev.Close()
 }
@@ -436,6 +442,24 @@ func (t *tun) Write(b []byte) (int, error) {
 	}
 }
 
+// WriteBatch writes multiple packets to the TUN device. offset is the number
+// of leading headroom bytes in each buffer that precede the packet data. In
+// the fallback path it returns how many packets were written before an error.
+func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
+	if t.wgDevice != nil {
+		return t.wgDevice.Write(bufs, offset)
+	}
+
+	// Fallback: write individually (shouldn't happen in normal operation).
+	// Skip the headroom so only the packet bytes are handed to Write.
+	for i, buf := range bufs {
+		if _, err := t.Write(buf[offset:]); err != nil {
+			return i, err
+		}
+	}
+	return len(bufs), nil
+}
+
 func (t *tun) deviceBytes() (o [16]byte) {
 	for i, c := range t.Device {
 		o[i] = byte(c)
diff --git a/udp/conn.go b/udp/conn.go
index 8c821d3..f3267c9 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -13,13 +13,27 @@ type EncReader func(
 	payload []byte,
 )
 
+// EncBatchReader is the callback invoked by ListenOutBatch with a batch of
+// received packets. Only the first count entries of addrs and payloads are
+// valid. Implementations reuse these slices and the underlying buffers between
+// calls, so the callback must not retain them after it returns.
+type EncBatchReader func(
+	addrs []netip.AddrPort,
+	payloads [][]byte,
+	count int,
+)
+
 type Conn interface {
 	Rebind() error
 	LocalAddr() (netip.AddrPort, error)
 	ListenOut(r EncReader)
+	// ListenOutBatch reads packets and hands them to r in batches of up to BatchSize.
+	ListenOutBatch(r EncBatchReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
 	ReloadConfig(c *config.C)
+	// BatchSize reports the maximum number of packets passed to one ListenOutBatch callback.
+	BatchSize() int
 	Close() error
 }
 
@@ -34,6 +48,9 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
 func (NoopConn) ListenOut(_ EncReader) {
 	return
 }
+func (NoopConn) ListenOutBatch(_ EncBatchReader) {
+	return
+}
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 }
@@ -43,6 +60,9 @@ func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
 func (NoopConn) ReloadConfig(_ *config.C) {
 	return
 }
+func (NoopConn) BatchSize() int {
+	return 1
+}
 func (NoopConn) Close() error {
 	return nil
 }
diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go
index b409d77..787f1c4 100644
--- a/udp/udp_darwin.go
+++ b/udp/udp_darwin.go
@@ -195,6 +195,40 @@ func (u *StdConn) ListenOut(r EncReader) {
 	}
 }
 
+// ListenOutBatch reads one packet at a time and delivers it to r as a batch
+// of one (Darwin falls back to single-packet reads). It returns once the
+// socket is closed.
+func (u *StdConn) ListenOutBatch(r EncBatchReader) {
+	buffer := make([]byte, MTU)
+	addrs := make([]netip.AddrPort, 1)
+	payloads := make([][]byte, 1)
+
+	for {
+		// Just read one packet at a time and call batch callback with count=1
+		n, rua, err := u.ReadFromUDPAddrPort(buffer)
+		if err != nil {
+			if errors.Is(err, net.ErrClosed) {
+				u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+				return
+			}
+
+			u.l.WithError(err).Error("unexpected udp socket receive error")
+			// The read failed, so skip the callback rather than pass it the
+			// address and payload of an unsuccessful read.
+			continue
+		}
+
+		addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
+		payloads[0] = buffer[:n]
+		r(addrs, payloads, 1)
+	}
+}
+
+// BatchSize always returns 1 because ListenOutBatch reads a single packet per callback.
+func (u *StdConn) BatchSize() int {
+	return 1
+}
+
 func (u *StdConn) Rebind() error {
 	var err error
 	if u.isV4 {
diff --git a/udp/udp_generic.go b/udp/udp_generic.go
index cb21e57..7c8cdf4 100644
--- a/udp/udp_generic.go
+++ b/udp/udp_generic.go
@@ -85,3 +85,42 @@ func (u *GenericConn) ListenOut(r EncReader) {
 		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
 	}
 }
+
+// ListenOutBatch reads one packet at a time and delivers it to r as a batch of one (generic-platform fallback).
+func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
+	buffer := make([]byte, MTU)
+	addrs := make([]netip.AddrPort, 1)
+	payloads := make([][]byte, 1)
+
+	for {
+		// Just read one packet at a time and call batch callback with count=1
+		n, rua, err := u.ReadFromUDPAddrPort(buffer)
+		if err != nil {
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
+		}
+
+		addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
+		payloads[0] = buffer[:n]
+		r(addrs, payloads, 1)
+	}
+}
+
+// WriteMulti sends each packet with WriteTo and returns how many were sent before the first error (fallback implementation).
+func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
+	for i := range packets {
+		err := u.WriteTo(packets[i], addrs[i])
+		if err != nil {
+			return i, err
+		}
+	}
+	return len(packets), nil
+}
+
+func (u *GenericConn) BatchSize() int {
+	return 1
+}
+
+func (u *GenericConn) Rebind() error {
+	return nil
+}
diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index c591a09..efb71c3 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -174,6 +174,49 @@ func (u *StdConn) ListenOut(r EncReader) {
 	}
 }
 
+// ListenOutBatch reads up to u.batch packets per receive call and passes them
+// to r together with the number received. The addrs and payloads slices are
+// reused on every iteration. It returns when the read reports an error.
+func (u *StdConn) ListenOutBatch(r EncBatchReader) {
+	var ip netip.Addr
+
+	msgs, buffers, names := u.PrepareRawMessages(u.batch)
+	read := u.ReadMulti
+	if u.batch == 1 {
+		read = u.ReadSingle
+	}
+
+	udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
+
+	// Pre-allocate slices for batch callback
+	addrs := make([]netip.AddrPort, u.batch)
+	payloads := make([][]byte, u.batch)
+
+	for {
+		n, err := read(msgs)
+		if err != nil {
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
+		}
+
+		udpBatchHist.Update(int64(n))
+
+		// Prepare batch data
+		for i := 0; i < n; i++ {
+			if u.isV4 {
+				ip, _ = netip.AddrFromSlice(names[i][4:8])
+			} else {
+				ip, _ = netip.AddrFromSlice(names[i][8:24])
+			}
+			addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
+			payloads[i] = buffers[i][:msgs[i].Len]
+		}
+
+		// Call batch callback with all packets
+		r(addrs, payloads, n)
+	}
+}
+
 func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
 	for {
 		n, _, err := unix.Syscall6(
@@ -463,6 +506,11 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
 	return nil
 }
 
+// BatchSize returns u.batch, the number of messages ListenOutBatch prepares for each read.
+func (u *StdConn) BatchSize() int {
+	return u.batch
+}
+
 func (u *StdConn) Close() error {
 	return syscall.Close(u.sysFd)
 }
diff --git a/udp/udp_tester.go b/udp/udp_tester.go
index 8d5e6c1..d88d8aa 100644
--- a/udp/udp_tester.go
+++ b/udp/udp_tester.go
@@ -116,6 +116,31 @@ func (u *TesterConn) ListenOut(r EncReader) {
 	}
 }
 
+func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
+	addrs := make([]netip.AddrPort, 1)
+	payloads := make([][]byte, 1)
+
+	for {
+		p, ok := <-u.RxPackets
+		if !ok {
+			return
+		}
+		addrs[0] = p.From
+		payloads[0] = p.Data
+		r(addrs, payloads, 1)
+	}
+}
+
+func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
+	for i := range packets {
+		err := u.WriteTo(packets[i], addrs[i])
+		if err != nil {
+			return i, err
+		}
+	}
+	return len(packets), nil
+}
+
 func (u *TesterConn) ReloadConfig(*config.C) {}
 
 func NewUDPStatsEmitter(_ []Conn) func() {
@@ -127,6 +152,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
 	return u.Addr, nil
 }
 
+func (u *TesterConn) BatchSize() int {
+	return 1
+}
+
 func (u *TesterConn) Rebind() error {
 	return nil
 }