use wg tun library; batching & locking improvements

This commit is contained in:
Jay Wren
2025-11-04 15:04:24 -05:00
parent 42bee7cf17
commit 5cc3ff594a
33 changed files with 1353 additions and 121 deletions

View File

@@ -102,7 +102,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
f.readOutsidePackets(via, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -145,7 +145,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
}
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
// Fallthrough to the bottom to record incoming traffic
@@ -167,7 +167,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
}
// Fallthrough to the bottom to record incoming traffic
@@ -210,7 +210,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -481,9 +481,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
err = newPacket(out, true, fwPacket)
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
Warnf("Error while validating inbound packet")
return false
}
@@ -498,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
@@ -562,3 +564,108 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
// We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
// and writes all successfully decrypted packets to TUN in a single operation
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
// Pre-allocate slice for accumulating successful decryptions
tunPackets := make([][]byte, 0, count)
for i := 0; i < count; i++ {
payload := payloads[i]
addr := addrs[i]
out := outs[i]
// Parse header
err := h.Parse(payload)
if err != nil {
if len(payload) > 1 {
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
continue
}
if addr.IsValid() {
if f.myVpnNetworksTable.Contains(addr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
}
continue
}
}
var hostinfo *HostInfo
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
if !f.handleEncrypted(ci, ViaSender{UdpAddr: addr}, h) {
continue
}
switch h.Subtype {
case header.MessageNone:
// Decrypt packet
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
continue
}
packetData := out[virtioNetHdrLen:]
err = newPacket(packetData, true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
continue
}
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
continue
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
}
continue
}
f.connectionManager.In(hostinfo)
// Add to batch for TUN write
tunPackets = append(tunPackets, out)
case header.MessageRelay:
// Skip relay packets in batch mode for now (less common path)
f.readOutsidePackets(ViaSender{UdpAddr: addr}, out[:virtioNetHdrLen], payload, h, fwPacket, lhf, nb, q, localCache)
default:
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
}
default:
// Handle non-Message types using single-packet path
f.readOutsidePackets(ViaSender{UdpAddr: addr}, out[:virtioNetHdrLen], payload, h, fwPacket, lhf, nb, q, localCache)
}
}
if len(tunPackets) > 0 {
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
if err != nil {
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
}
}
}