what about with bad GRO on UDP

This commit is contained in:
JackDoan
2025-11-10 14:47:38 -06:00
parent 42591c2042
commit c645a45438
8 changed files with 411 additions and 223 deletions

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/packet"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
@@ -216,21 +217,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.connectionManager.In(hostinfo)
}
func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, outNeedsTun []*int, packets [][]byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
for i, packet := range packets {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
}
return
}
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0]
ip := pkt.AddrPort()
//l.Error("in packet ", header, packet[HeaderLen:])
if ip[i].IsValid() {
if f.myVpnNetworksTable.Contains(ip[i].Addr()) {
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
@@ -238,182 +232,194 @@ func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, ou
}
}
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
//todo per-segment!
for segment := range pkt.Segments() {
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip[i], h) {
err := h.Parse(segment)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(segment) > 1 {
f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
}
return
}
switch h.Subtype {
case header.MessageNone:
out[i] = out[i][:0]
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i][:0], outNeedsTun[i], packet, fwPacket, nb, q, localCache) {
return
}
case header.MessageRelay:
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
out[i], err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i], signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
}
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip[i])
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
return
}
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i][:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache) {
return
}
case header.MessageRelay:
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i], false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
}
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip[i], h) {
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt lighthouse packet")
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt test packet")
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip, nil, segment, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt Control packet")
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
return
}
f.handleHostRoaming(hostinfo, ip)
lhf.HandleRequest(ip[i], hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip[i], h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt test packet")
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip[i])
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i])
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip[i], nil, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip[i], h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip[i], h) {
return
}
hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, ip[i], h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return
f.connectionManager.In(hostinfo)
}
f.handleHostRoaming(hostinfo, ip[i])
f.connectionManager.In(hostinfo)
}
}
@@ -666,16 +672,17 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out []byte, outNeedsTun *int, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0]
out.Segments[out.SegCounter], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Segments[out.SegCounter], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
return false
}
err = newPacket(out, true, fwPacket)
err = newPacket(out.Segments[out.SegCounter], true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
@@ -692,7 +699,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
f.rejectOutside(out.Segments[out.SegCounter], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
@@ -702,7 +709,8 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
}
f.connectionManager.In(hostinfo)
*outNeedsTun = len(out)
pkt.OutLen += len(inSegment)
out.SegCounter++
return true
}