From 4fb5cdb4faaa1c47ef0c8e59fb46641db707dca9 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 6 May 2026 12:23:27 -0400 Subject: [PATCH] refactor readOutsidePackets (#1642) * refactor readOutsidePackets They layout of this method is confusing and relys on certain parts to return early for things to work correctly. Change the ordering of the logic so that we do this: - Handle unencrypted packets - Decrypt packet - Handle encrypted packets This way, nothing can sneak through unencrypted to where it shouldn't be. * fix comment * code review comments * check for expected type/subtype * check header version * log header * need to handle TestReply * clean roaming / connectionManager * dont need to roam here now, we do it earlier * cleanup metrics and errors * rxInvalid * debug logger checks * ErrOutOfWindow --- header/header.go | 14 ++ message_metrics.go | 8 + outside.go | 413 +++++++++++++++++++++------------------------ 3 files changed, 210 insertions(+), 225 deletions(-) diff --git a/header/header.go b/header/header.go index f22509b8..b973141f 100644 --- a/header/header.go +++ b/header/header.go @@ -174,6 +174,10 @@ func (h *H) SubTypeName() string { return SubTypeName(h.Type, h.Subtype) } +func (h *H) IsValidSubType() bool { + return IsValidSubType(h.Type, h.Subtype) +} + // SubTypeName will transform a nebula message sub type into a human string func SubTypeName(t MessageType, s MessageSubType) string { if n, ok := subTypeMap[t]; ok { @@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string { return "unknown" } +func IsValidSubType(t MessageType, s MessageSubType) bool { + if n, ok := subTypeMap[t]; ok { + if _, ok := (*n)[s]; ok { + return true + } + } + + return false +} + // NewHeader turns bytes into a header func NewHeader(b []byte) (*H, error) { h := new(H) diff --git a/message_metrics.go b/message_metrics.go index 10e8472c..45de9a5c 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -13,6 +13,8 @@ type MessageMetrics struct { rxUnknown metrics.Counter txUnknown metrics.Counter + + rxInvalid metrics.Counter } func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { @@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int } } } +func (m *MessageMetrics) RxInvalid(i int64) { + if m != nil && m.rxInvalid != nil { + m.rxInvalid.Inc(i) + } +} func newMessageMetrics() *MessageMetrics { gen := func(t string) [][]metrics.Counter { @@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics { rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil), txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil), + rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil), } } diff --git a/outside.go b/outside.go index 1e00a0a9..17013ed3 100644 --- a/outside.go +++ b/outside.go @@ -20,23 +20,46 @@ const ( minFwPacketLen = 4 ) +var ErrOutOfWindow = errors.New("out of window packet") + func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + // TODO: record metrics for rx holepunch/punchy packets? if len(packet) > 1 { - f.l.Info("Error while parsing inbound packet", - "from", via, - "error", err, - "packet", packet, - ) + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) + } + } + return + } + + if h.Version != header.Version { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected header version received", "from", via) + } + return + } + + // Check before processing to see if this is a expected type/subtype + if !h.IsValidSubType() { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected packet received", "from", via) } return } - //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { + f.messageMetrics.RxInvalid(1) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("Refusing to process double encrypted packet", "from", via) } @@ -44,215 +67,192 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } } + // don't keep Rx metrics for message type, since you can see those in the tun metrics + if h.Type != header.Message { + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + } + + // Unencrypted packets + switch h.Type { + case header.Handshake: + f.handshakeManager.HandleIncoming(via, packet, h) + return + + case header.RecvError: + f.handleRecvError(via.UdpAddr, h) + return + } + + // Relay packets are special + isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay) + var hostinfo *HostInfo - // verify if we've seen this index before, otherwise respond to the handshake initiation - if h.Type == header.Message && h.Subtype == header.MessageRelay { + if isMessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } - var ci *ConnectionState - if hostinfo != nil { - ci = hostinfo.ConnectionState + // At this point we should have a valid existing tunnel, verify and send + // recvError if necessary + if hostinfo == nil || hostinfo.ConnectionState == nil { + if !via.IsRelayed { + f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) + } + return } + // All remaining packets are encrypted + ci := hostinfo.ConnectionState + if !ci.window.Check(f.l, h.MessageCounter) { + return + } + + // Relay packets are special + if isMessageRelay { + f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) + + return + } + + out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + if err != nil { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Failed to decrypt packet", + "error", err, + "from", via, + "header", h, + ) + } + return + } + + // Roam before we respond + f.handleHostRoaming(hostinfo, via) + f.connectionManager.In(hostinfo) + switch h.Type { case header.Message: - if !f.handleEncrypted(ci, via, h) { - return - } - switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { - return - } - case header.MessageRelay: - // The entire body is sent as AD, not encrypted. - // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. - // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's - // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice - // which will gracefully fail in the DecryptDanger call. - signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] - signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { - return - } - // Successfully validated the thing. Get rid of the Relay header. - signedPayload = signedPayload[header.Len:] - // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, via) - // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo) - f.connectionManager.RelayUsed(h.RemoteIndex) - - relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) - if !ok { - // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This should never happen. - hostinfo.logger(f.l).Error("HostInfo missing remote relay index", - "vpnAddrs", hostinfo.vpnAddrs, - "remoteIndex", h.RemoteIndex, - ) - return - } - - switch relay.Type { - case TerminalType: - // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - via = ViaSender{ - UdpAddr: via.UdpAddr, - relayHI: hostinfo, - remoteIdx: relay.RemoteIndex, - relay: relay, - IsRelayed: true, - } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) - return - case ForwardingType: - // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) - if err != nil { - hostinfo.logger(f.l).Info("Failed to find target host info by ip", - "relayTo", relay.PeerAddr, - "error", err, - "hostinfo.vpnAddrs", hostinfo.vpnAddrs, - ) - return - } - - // If that relay is Established, forward the payload through it - if targetRelay.State == Established { - switch targetRelay.Type { - case ForwardingType: - // Forward this packet through the relay tunnel - // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) - return - case TerminalType: - hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") - } - } else { - hostinfo.logger(f.l).Info("Unexpected target relay state", - "relayTo", relay.PeerAddr, - "relayFrom", hostinfo.vpnAddrs[0], - "targetRelayState", targetRelay.State, - ) - return - } - } + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) + return } case header.LightHouse: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - //TODO: assert via is not relayed - lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) - - // Fallthrough to the bottom to record incoming traffic + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f) case header.Test: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { + switch h.Subtype { + case header.TestReply: + // No-op, useful for the Roaming and connectionManager side-effects above + case header.TestRequest: + f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h) return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt test packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - if h.Subtype == header.TestRequest { - // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding - f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) - } - - // Fallthrough to the bottom to record incoming traffic - - // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they - // are unauthenticated - - case header.Handshake: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(via, packet, h) - return - - case header.RecvError: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(via.UdpAddr, h) - return - case header.CloseTunnel: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) - f.closeTunnel(hostinfo) - return case header.Control: - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt Control packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - f.relayManager.HandleControlMsg(hostinfo, d, f) + f.relayManager.HandleControlMsg(hostinfo, out, f) default: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) - } + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h) + } +} + +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] + var err error + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) + if err != nil { + return + } + // Successfully validated the thing. Get rid of the Relay header. + signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. + f.handleHostRoaming(hostinfo, via) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet + f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } - f.handleHostRoaming(hostinfo, via) + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. + via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) + return + } - f.connectionManager.In(hostinfo) + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel + // Find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + return + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type) + } + return + } + } else { + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) + return + } + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type) + } + } } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote @@ -300,23 +300,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { } -// handleEncrypted returns true if a packet should be processed, false otherwise -func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { - // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect - if ci == nil { - if !via.IsRelayed { - f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) - } - return false - } - // If the window check fails, refuse to process the packet, but don't send a recv error - if !ci.window.Check(f.l, h.MessageCounter) { - return false - } - - return true -} - var ( ErrPacketTooShort = errors.New("packet is too short") ErrUnknownIPVersion = errors.New("packet is an unknown ip version") @@ -523,38 +506,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h) - } - return nil, errors.New("out of window packet") + return nil, ErrOutOfWindow } return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { - var err error - - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) - return false - } - - err = newPacket(out, true, fwPacket) +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { + err := newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", "error", err, "packet", out, ) - return false - } - - if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) - } - return false + return } dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) @@ -568,15 +533,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out "reason", dropReason, ) } - return false + return } - f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { f.l.Error("Failed to write to tun", "error", err) } - return true } func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {