zero copy even with virtio header

This commit is contained in:
Jay Wren
2025-11-19 12:03:38 -05:00
parent 518a78c9d2
commit 8b32382cd9
5 changed files with 27 additions and 63 deletions

View File

@@ -22,6 +22,7 @@ import (
) )
const mtu = 9001 const mtu = 9001
const virtioNetHdrLen = overlay.VirtioNetHdrLen
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
@@ -266,13 +267,16 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
// Allocate plaintext buffer with virtio header headroom to avoid copies on TUN write
plaintext := make([]byte, virtioNetHdrLen+udp.MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(fromUdpAddr, nil, plaintext[:virtioNetHdrLen], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
}) })
} }
@@ -298,11 +302,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu) packet := make([]byte, mtu)
// Allocate out buffer with virtio header headroom (10 bytes) to avoid copies on write // Allocate out buffer with virtio header headroom (10 bytes) to avoid copies on write
const virtioNetHdrLen = 10
outBuf := make([]byte, virtioNetHdrLen+mtu) outBuf := make([]byte, virtioNetHdrLen+mtu)
out := outBuf[virtioNetHdrLen:] // Use slice starting after headroom out := outBuf[virtioNetHdrLen:] // Use slice starting after headroom
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
@@ -324,7 +327,6 @@ func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) {
func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) { func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) {
batchSize := batchReader.BatchSize() batchSize := batchReader.BatchSize()
const virtioNetHdrLen = 10
// Allocate buffers for batch reading // Allocate buffers for batch reading
bufs := make([][]byte, batchSize) bufs := make([][]byte, batchSize)
@@ -346,7 +348,7 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe
batchAddrs := make([]netip.AddrPort, 0, batchSize) batchAddrs := make([]netip.AddrPort, 0, batchSize)
// Pre-allocate nonce buffer (reused for all encryptions) // Pre-allocate nonce buffer (reused for all encryptions)
nb := make([]byte, 12, 12) nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)

View File

@@ -95,8 +95,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch relay.Type { switch relay.Type {
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -474,9 +473,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false return false
} }
err = newPacket(out, true, fwPacket) packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out). hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
Warnf("Error while validating inbound packet") Warnf("Error while validating inbound packet")
return false return false
} }
@@ -491,7 +492,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
if dropReason != nil { if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in // This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket). hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason). WithField("reason", dropReason).

View File

@@ -11,6 +11,7 @@ import (
) )
const DefaultMTU = 1300 const DefaultMTU = 1300
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
// TODO: We may be able to remove routines // TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)

View File

@@ -66,10 +66,6 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
const (
virtioNetHdrLen = 10 // Size of virtio_net_hdr structure
)
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser // wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
// This allows multiqueue readers to use the same wireguard Device batching as the main device // This allows multiqueue readers to use the same wireguard Device batching as the main device
type wgDeviceWrapper struct { type wgDeviceWrapper struct {
@@ -92,27 +88,11 @@ func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
} }
func (w *wgDeviceWrapper) Write(b []byte) (int, error) { func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
// Check if buffer has the expected headroom pattern to avoid copy // Buffer b should have virtio header space (10 bytes) at the beginning
var buf []byte // The decrypted packet data starts at offset 10
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
if cap(b) >= len(b)+virtioNetHdrLen { bufs := [][]byte{b}
buf = b[:cap(b)] n, err := w.dev.Write(bufs, VirtioNetHdrLen)
if len(buf) == len(b)+virtioNetHdrLen {
// Perfect! Buffer has headroom, no copy needed
buf = buf[:len(b)+virtioNetHdrLen]
} else {
// Unexpected capacity, safer to copy
buf = make([]byte, virtioNetHdrLen+len(b))
copy(buf[virtioNetHdrLen:], b)
}
} else {
// No headroom, need to allocate and copy
buf = make([]byte, virtioNetHdrLen+len(b))
copy(buf[virtioNetHdrLen:], b)
}
bufs := [][]byte{buf}
n, err := w.dev.Write(bufs, virtioNetHdrLen)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -419,32 +399,11 @@ func (t *tun) BatchSize() int {
func (t *tun) Write(b []byte) (int, error) { func (t *tun) Write(b []byte) (int, error) {
if t.wgDevice != nil { if t.wgDevice != nil {
// Use wireguard device which handles virtio headers internally // Buffer b should have virtio header space (10 bytes) at the beginning
// Check if buffer has the expected headroom pattern: // The decrypted packet data starts at offset 10
// cap(b) should be len(b) + virtioNetHdrLen, indicating pre-allocated headroom // Pass the full buffer to WireGuard with offset=virtioNetHdrLen
var buf []byte bufs := [][]byte{b}
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
if cap(b) >= len(b)+virtioNetHdrLen {
// Buffer likely has headroom - use unsafe to access it
// Create a slice that includes the headroom by re-slicing from capacity
buf = b[:cap(b)]
// Check if we have exactly the right amount of extra capacity
if len(buf) == len(b)+virtioNetHdrLen {
// Perfect! This buffer was allocated with headroom, no copy needed
buf = buf[:len(b)+virtioNetHdrLen]
} else {
// Unexpected capacity, safer to copy
buf = make([]byte, virtioNetHdrLen+len(b))
copy(buf[virtioNetHdrLen:], b)
}
} else {
// No headroom, need to allocate and copy
buf = make([]byte, virtioNetHdrLen+len(b))
copy(buf[virtioNetHdrLen:], b)
}
bufs := [][]byte{buf}
n, err := t.wgDevice.Write(bufs, virtioNetHdrLen)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@@ -6,6 +6,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
_ "net/http/pprof"
"runtime" "runtime"
"strconv" "strconv"
"time" "time"